41 files changed, 7011 insertions, 0 deletions
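diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 0000000..6d9d08f
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,2 @@
+chor as a premise, can we learn a useful, terminal-conditioned local credit signal to replace the coarse credit assignment that the fixed random feedback matrices of Feedback Alignment (FA/DFA) provide?
+
+Our goal is not to chase the highest classification accuracy right away, but to test whether a new hypothesis is viable: rather than bridging hidden states, bridge the credit / costate / value field. In other words, the object we want to learn is not h_l \mapsto h_{l+1} or h_l \mapsto h_L, but a terminal-conditioned scalar value / desirability field
+    V_\phi(h_l, t_l, s), \qquad t_l = l/L,
+with the local credit assignment at each layer defined as
+    a_l := \nabla_{h_l} V_\phi(h_l,t_l,s).
+Here s is a global modulation code derived from the terminal error. The simplest choice is
+    e_T = \nabla_{\hat y}\ell(\hat y,y),\qquad s = P^\top \mathrm{sg}(e_T),
+where \mathrm{sg}[\cdot] denotes stop-gradient. For a 10-class classification task, start with P=I, so s=e_T\in\mathbb R^{10}.
+
+The project has two phases. Phase one is a linear-quadratic residual sanity check; its purpose is not task performance but to verify how closely the a_l learned by the credit bridge resembles the true local gradient. Phase two is the main experiment on a deep residual MLP, comparing BP, DFA, state bridge, and credit bridge.
+
+The most important constraint of the whole project: no BP anchor may be used in the hidden layers during training. That is, intermediate layers must not use the exact backprop hidden gradient as a supervision signal. Only two things are allowed: first, the local exact gradient at the output layer, because the output layer sees the loss directly; second, computing BP hidden gradients during offline evaluation as a diagnostic metric, which must never be used for training.
+
+Proceed in this order: first complete the linear sanity check and obtain clear diagnostic results, then run the main experiment. Do not reverse the order.
+
+⸻
+
+I. Project background and core hypothesis
+
+Standard BP uses, at the hidden layers,
+    \delta_l^{\mathrm{BP}} = J_{l\to L}^\top e_T,
+where J_{l\to L} is the Jacobian from layer l to the output. FA/DFA approximate this object with a fixed random feedback matrix; DFA, for example, uses
+    \delta_l^{\mathrm{DFA}} = B_l e_T.
+Our suspicion is that the key is not "finding a better static B_l" but learning a state-dependent, sample-dependent, depth-dependent credit field. That is, we want to view the top-down signal a hidden layer should receive as the gradient of a terminal-conditioned value field:
+    a_l = \nabla_{h_l} V_\phi(h_l,t_l,s).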
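To make the contrast between \delta_l^{\mathrm{BP}} = J_{l\to L}^\top e_T and \delta_l^{\mathrm{DFA}} = B_l e_T concrete, here is a minimal pure-Python sketch on a tiny linear chain h_{k+1} = M_k h_k with output \hat y = C h_L. The helper names (`bp_credit`, `dfa_credit`, `cosine`) and the numeric values are illustrative assumptions, not part of the project code, and the matrix `B_0` only stands in for a random draw.

```python
import math

def matvec(A, x):
    """Multiply matrix A (list of rows) by vector x."""
    return [sum(a * b for a, b in zip(row, x)) for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)

def bp_credit(maps, C, e_T, l):
    """delta_l^BP = J_{l->L}^T e_T for the chain h_{k+1} = M_k h_k, y_hat = C h_L."""
    delta = matvec(transpose(C), e_T)        # credit at h_L
    for M in reversed(maps[l:]):             # pull back through M_{L-1}, ..., M_l
        delta = matvec(transpose(M), delta)
    return delta

def dfa_credit(B_l, e_T):
    """delta_l^DFA = B_l e_T with a fixed feedback matrix B_l of shape (d, C)."""
    return matvec(B_l, e_T)

# Hypothetical toy values: d = 2 hidden units, one output, one linear layer.
maps = [[[1.0, 1.0], [0.0, 1.0]]]
C = [[1.0, 2.0]]
e_T = [3.0]
B_0 = [[0.5], [-1.0]]                        # stands in for a random draw

bp = bp_credit(maps, C, e_T, 0)
dfa = dfa_credit(B_0, e_T)
print("BP credit:", bp)
print("DFA credit:", dfa)
print("cosine:", cosine(bp, dfa))
```

With an unlucky feedback matrix the DFA credit can even point against the BP credit, which is the kind of mismatch a state- and sample-dependent credit field is meant to avoid.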
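+The credit-field idea can be understood as a credit bridge with no intermediate anchor. It is a different object from a state bridge that merely predicts hidden states: the state bridge tries to predict h_L or intermediate states from h_l, whereas the credit bridge directly learns the sensitivity of the loss to the local state.
+
+We need to verify two concrete propositions.
+
+First, whether a state bridge is insufficient to produce useful local credit. That is, even if it predicts h_L well, the credit it computes,
+    a_l^{\mathrm{state}} = \nabla_{h_l} \ell(W_{\rm out}\hat h_L,y),
+may still fail to guide local updates.
+
+Second, whether a credit bridge, with no intermediate BP anchor, can produce local credit that is more useful than DFA's. "More useful" is judged mainly through diagnostics such as local perturbation correlation, the nudging test, and offline BP cosine, not only through final test accuracy.
+
+⸻
+
+II. Methods you need to implement
+
+Implement at least the following 4 methods, and make sure all of them use a forward architecture, optimizer family, and training setup that are as identical as possible.
+
+Method 1: BP baseline
+
+This is the upper-bound baseline. Standard end-to-end backprop is allowed here. Its main role is to provide an accuracy upper bound and a hidden-gradient reference at evaluation time. Training for this method can of course call loss.backward() on the whole network.
+
+Method 2: DFA baseline
+
+This is the strong baseline without a hidden BP anchor. Let the network have L residual blocks, hidden-state dimension d, and C classes. For each block, sample a fixed random matrix
+    B_l \in \mathbb R^{d\times C},
+and keep it unchanged throughout training. The output error is
+    e_T = \nabla_{\hat y}\ell(\hat y,y).
+The layer then uses
+    a_l^{\mathrm{DFA}} = B_l e_T.
+The parameters of block l are updated with a local surrogate:
+    \mathcal L_l^{\mathrm{local}} = \langle F_l(h_l;\theta_l),\mathrm{sg}[a_{l+1}^{\mathrm{DFA}}]\rangle.
+Note: in the DFA method, do not take the shortcut of letting loss.backward() pass through hidden layers. Each block's update must depend only on its own local forward pass and the fixed feedback it receives.
+
+Method 3: State bridge baseline
+
+This is a deliberately included "wrong object" baseline. Its purpose is not best performance but to verify that "bridging state is not the same as bridging credit".
+
+Define a shared or semi-shared predictor
+    G_\psi(h_l,t_l,s)\to \hat h_L^{(l)}.
+Its inputs are the current layer's hidden state, the layer position t_l=l/L, and the terminal modulation code s. Its training objective is terminal-state regression:
+    \mathcal L_{\mathrm{state}} = \sum_{l=0}^{L-1}\|G_\psi(h_l,t_l,s)-\mathrm{sg}[h_L]\|_2^2.
+Supervise only layers with l<L, because l=L is a trivial point. When training the state bridge, the input hidden state must be a detach().requires_grad_(True) copy, so that training the state predictor does not feed back into the forward network.
+
+Once G_\psi is trained (or while it is trained in sync), define the state-bridge credit:
+    a_l^{\mathrm{state}} = \nabla_{h_l}\ell(W_{\rm out}G_\psi(h_l,t_l,s),y).
+Here a_l^{\mathrm{state}} is the gradient with respect to the predictor input h_l, not with respect to the forward network's parameters. Then, as with DFA, use it for the block's local surrogate update:
+    \mathcal L_l^{\mathrm{local}} = \langle F_l(h_l;\theta_l),\mathrm{sg}[a_{l+1}^{\mathrm{state}}]\rangle.
+
+Method 4: Credit bridge (main method)
+
+This is the core method of the project. Define a scalar value network
+    V_\phi(h_l,t_l,s)\in\mathbb R.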
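+The value network V_\phi again takes as input the hidden state, the depth time t_l=l/L, and the terminal modulation code s. The layer's credit signal is then defined as
+    a_l = \nabla_{h_l}V_\phi(h_l,t_l,s).
+This a_l is sent to the corresponding block for its local surrogate update.
+
+Training of the credit bridge starts from the simplest version, implementing only two loss terms.
+
+The first is the terminal boundary term:
+    \mathcal L_{\mathrm{term}} = \Big( V_\phi(h_L,1,s)-\mathrm{sg}[\ell(\hat y,y)] \Big)^2.
+The second is first-order bridge consistency. Let the forward block have the residual form
+    h_{l+1}=h_l+F_l(h_l;\theta_l).
+Introduce small-noise reference dynamics
+    \tilde h_{l+1}^{(k)} = h_l + F_l(h_l;\theta_l) + \sigma_l \xi_k,\qquad \xi_k\sim\mathcal N(0,I).
+Then use an EMA target network \bar\phi to construct the target:
+    \hat V_l^{\mathrm{tgt}} = -\lambda \log \left( \frac1K\sum_{k=1}^K \exp\Big( -\frac{V_{\bar\phi}(\tilde h_{l+1}^{(k)},t_{l+1},s)}{\lambda} \Big) \right).
+The bridge loss is defined as
+    \mathcal L_{\mathrm{bridge}} = \sum_{l=0}^{L-1} \Big( V_\phi(h_l,t_l,s)-\mathrm{sg}[\hat V_l^{\mathrm{tgt}}] \Big)^2.
+The full feedback loss is
+    \mathcal L_\phi = \mathcal L_{\mathrm{term}} + \mathcal L_{\mathrm{bridge}}.
+In this minimal version, do not add the FM auxiliary or the MaxCal smoothness regularizer yet. The reason is practical: adding all second-order terms from the start raises engineering complexity considerably and makes it harder to judge quickly whether the direction has any signal.
+
+When the forward network is updated, the local surrogate for block l is
+    \mathcal L_l^{\mathrm{local}} = \langle F_l(h_l;\theta_l),\mathrm{sg}[a_{l+1}]\rangle.
+The output-layer weights W_{\rm out} may be updated directly with the exact output gradient of CE; this is allowed because the output layer sees the loss locally.
+
+If the minimal version shows signal, build a second version: on the toy and on small models, add an FM-style auxiliary that makes the credit field smoother from layer to layer. Define a random intermediate point
+    \bar h_{l,\tau} = h_l+\tau F_l(h_l;\theta_l)+\sqrt{\tau(1-\tau)}\sigma_l\epsilon, \qquad \tau\sim U(0,1),\ \epsilon\sim\mathcal N(0,I),
+and an interpolated target
+    \bar a_{l,\tau}^{\rm tgt} = (1-\tau)\,\mathrm{sg}[a_l] + \tau\,\mathrm{sg}[a_{l+1}].
+Then add
+    \mathcal L_{\mathrm{fm}} = \gamma\sum_{l=0}^{L-1} \left\| \nabla_h V_\phi(\bar h_{l,\tau},(l+\tau)/L,s) - \bar a_{l,\tau}^{\rm tgt} \right\|_2^2.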
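The bridge target \hat V_l^{\mathrm{tgt}} is a soft-min (a negative log-mean-exp) over the K noisy next-layer values, and the target network is an exponential moving average of the online value network. The following pure-Python sketch shows both pieces on plain floats; `softmin_target` and `ema_update` are hypothetical helper names, and a real implementation would presumably work on tensors (for example via `torch.logsumexp`).

```python
import math

def softmin_target(next_values, lam):
    """-lam * log((1/K) * sum_k exp(-V_k / lam)), with a max-shift for stability."""
    scaled = [-v / lam for v in next_values]
    shift = max(scaled)
    log_sum = shift + math.log(sum(math.exp(z - shift) for z in scaled))
    return -lam * (log_sum - math.log(len(next_values)))

def ema_update(target_params, online_params, momentum):
    """Move target-network parameters toward the online ones (flat float lists here)."""
    return [momentum * t + (1.0 - momentum) * p
            for t, p in zip(target_params, online_params)]

# Small lam behaves like a min over samples, large lam like their mean.
values = [1.0, 3.0]
for lam in (0.1, 1.0, 100.0):
    print(lam, softmin_target(values, lam))
```

The target always lies between the minimum and the mean of the sampled values, so \lambda controls how optimistic the bridge recursion is about the noisy next state.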
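+Note: the FM auxiliary term trains through \nabla_h V_\phi, which requires second-order autograd. Run it only on the toy and on small-scale main experiments at first; do not enable it at full scale from the start.
+
+⸻
+
+III. Experiment phase A: linear-quadratic residual sanity check
+
+The purpose of this experiment is to verify, in a system where the exact costate is available analytically, whether the a_l learned by the credit bridge is close to the true local gradient. In this phase all forward-dynamics parameters are fixed: the forward net is not trained, only the feedback / bridge models are. This separates "is the credit learned correctly" from "is forward training stable".
+
+Implement the following system. Let the hidden dimension be d=64, the terminal output dimension m=10, and the number of layers L=12. For each layer, sample a stable linear map
+    M_l = I + A_l,
+where A_l is a random matrix scaled to a small spectral norm, for example \|A_l\|_2\le 0.05. The forward dynamics are
+    h_{l+1}=M_l h_l + \sigma \xi_l,\qquad \xi_l\sim\mathcal N(0,I),\ \sigma=0.03.
+Data can be generated online. Take
+    h_0 \sim \mathcal N(0,I_d),\qquad y\sim \mathcal N(0,I_m),
+and define the terminal loss as
+    \Phi(h_L,y)=\frac12\|Ch_L-y\|_2^2,
+where C\in\mathbb R^{m\times d} is a fixed random matrix.
+
+The exact costate of this system is available in closed form. The terminal gradient is
+    a_L^{\mathrm{exact}} = C^\top(Ch_L-y),
+and the recursion is
+    a_l^{\mathrm{exact}} = M_l^\top a_{l+1}^{\mathrm{exact}}.
+Use this exact costate for evaluation, but do not use it to train the credit bridge.
+
+This phase compares at least 3 methods: DFA, state bridge, and credit bridge. BP does not have to be trained here because the forward net is fixed; the exact costate is all you need as the evaluation reference.
+
+The main evaluation metrics for this phase are as follows.
+
+First, exact costate cosine:
+    \Gamma_l = \mathbb E\Big[ \cos(a_l,a_l^{\mathrm{exact}}) \Big].
+Report it per layer and also as an across-layer average.
+
+Second, local perturbation correlation. For each layer, sample M random directions v_{l,j}\sim \mathcal N(0,I), normalize them, and take a small perturbation \epsilon. The predicted first-order change in loss is
+    \Delta_{l,j}^{\mathrm{pred}} = \langle a_l,\epsilon v_{l,j}\rangle.
+The true change is the difference in terminal loss obtained by re-rolling the subsequent dynamics from that layer:
+    \Delta_{l,j}^{\mathrm{true}} = \Phi(h_l+\epsilon v_{l,j})-\Phi(h_l).
+Compute the per-layer Pearson correlation
+    \rho_l = \mathrm{corr}\!\left(\Delta_{l,j}^{\mathrm{pred}},\Delta_{l,j}^{\mathrm{true}}\right).
+This metric is very important because it does not depend on BP; it only tests whether a_l really is a valid linear approximation of the loss with respect to the local state.
+
+Third, the nudging test. For each layer, apply
+    h_l' = h_l - \eta \frac{a_l}{\mathrm{RMS}(a_l)+10^{-6}},
+then continue rolling the dynamics from that layer and check whether the terminal loss decreases:
+    \Delta_l^{\mathrm{nudge}} = \Phi(h_l')-\Phi(h_l).
+A good credit should give \Delta_l^{\mathrm{nudge}}<0 and should beat both a random direction and DFA.
+
+Fourth, bridge residual. For the credit bridge, compute
+    R_l = \left| V_\phi(h_l,t_l,s) + \lambda\log \left( \frac1K\sum_k e^{-V_{\bar\phi}(\tilde h_{l+1}^{(k)},t_{l+1},s)/\lambda} \right) \right|.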
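+The bridge residual R_l checks whether the bridge recursion has actually been learned.
+
+The goal of this phase is not perfection but to answer one question: with no intermediate BP anchor at all, can the credit bridge at least look more like a true local-gradient object than the state bridge does? If there is no signal at all on the toy, do not move on to the main experiment; if \Gamma_l, \rho_l, and \Delta_l^{\mathrm{nudge}} on the toy are all clearly better than for the state bridge and DFA, continue.
+
+⸻
+
+IV. Experiment phase B: deep residual MLP main experiment
+
+The main experiment uses image classification. The preferred dataset is CIFAR-10. If debugging on CIFAR-10 is too slow, run a smoke test on FashionMNIST first, but the final main results must be on CIFAR-10.
+
+Use a deep residual MLP as the forward architecture, not a Transformer, because this phase aims to validate the credit semantics, not to take on large-model training. Recommended configuration:
+
+Flatten the input, then apply a linear embedding to hidden dimension d=512 or d=768. Use L=12 residual blocks in total. Each block is pre-LayerNorm plus a two-layer MLP:
+    h_{l+1} = h_l + W_{2,l}\,\mathrm{GELU}(W_{1,l}\,\mathrm{LN}(h_l)).
+The output head is
+    \hat y = W_{\rm out}\,\mathrm{LN}(h_L).
+The training loss is standard cross-entropy. All methods use the same data augmentation, batch size, optimizer family, weight decay, and number of training epochs. Start with AdamW, batch size 128, and 100 to 200 epochs. Run every method with 3 random seeds.
+
+The main experiment must compare these 4 methods: BP, DFA, state bridge, credit bridge. If the engineering cost is acceptable, add a credit bridge + FM auxiliary variant as an extra experiment.
+
+This phase has two very important implementation constraints.
+
+First, for non-BP methods, never let the global loss backpropagate through hidden layers. The correct procedure is: run the forward network to obtain h_0,\dots,h_L and the terminal error e_T. Then construct detached hidden copies for the feedback network:
+    \tilde h_l = \mathrm{detach}(h_l),\qquad \tilde h_l.\mathrm{requires\_grad}=True.
+Train the state bridge or value net on these detached copies. This way the feedback network's gradients cannot silently train the forward network.
+
+Second, each block's parameter update must also depend only on the block's own local forward pass and the credit coming from the feedback model. Concretely, when updating block l, take a fresh
+    \bar h_l = \mathrm{detach}(h_l),
+compute F_l(\bar h_l;\theta_l), and minimize
+    \mathcal L_l^{\mathrm{local}} = \langle F_l(\bar h_l;\theta_l), \mathrm{sg}[a_{l+1}] \rangle.
+Autograd then flows only into \theta_l and not into earlier blocks. Each block can have its own optimizer, or they can share one as long as zero_grad() and step() are handled carefully block by block. If that seems error-prone, a separate optimizer per block is the recommended choice.
+
+The output head W_{\rm out} may be updated directly with the exact output gradient of CE, but detach h_L during that update so nothing backpropagates into the hidden layers.
+
+⸻
+
+V. Evaluation metrics for the main experiment
+
+Final accuracy must of course be reported, but it is not the only core metric. What we really want to know is whether the local credit produced by the credit bridge resembles the true sensitivity of the loss to the local state more closely than DFA and the state bridge do.
+
+Compute and report at least the following metrics.
+
+1. Validation accuracy / train loss / test accuracy
+This is the standard metric, but do not look at it alone. Report mean±std over 3 seeds for all methods.
+
+2. Local perturbation correlation \rho_l
+This is the most important BP-free diagnostic. For a validation batch, take the hidden state h_l at some layer and sample M random directions v_{l,j}, for example M=16 or 32. For each direction, compute the predicted change in loss:
+    \Delta_{l,j}^{\mathrm{pred}} = \langle a_l, \epsilon v_{l,j}\rangle.
+Then replace the hidden state at that layer with h_l+\epsilon v_{l,j}, rerun only the tail from that layer to the output, and obtain the true change in loss:
+    \Delta_{l,j}^{\mathrm{true}} = \ell(h_l+\epsilon v_{l,j})-\ell(h_l).
+Compute the Pearson correlation:
+    \rho_l = \mathrm{corr}(\Delta_{l,j}^{\mathrm{pred}}, \Delta_{l,j}^{\mathrm{true}}).
+Report this metric per layer as it evolves over training. A useful credit should stay positively correlated across many layers and beat DFA.
+
+3. Hidden nudging test
+For each layer, define the normalized nudge:
+    h_l' = h_l - \eta \frac{a_l}{\mathrm{RMS}(a_l)+10^{-6}}.
+Then rerun the tail from that layer and compare the change in loss:
+    \Delta_l^{\mathrm{nudge}} = \ell(h_l') - \ell(h_l).
+A good credit should make \Delta_l^{\mathrm{nudge}} negative and beat both a random direction and DFA. Fix a set of \eta values, for example \{10^{-3},3\cdot10^{-3},10^{-2}\}, and test all of them, so that a poor step-size choice does not lead to a wrong conclusion.
+
+4. Offline BP cosine
+This metric must not be used in training but may be used in evaluation. Run a separate full BP pass on a validation batch to obtain the true hidden gradient
+    \delta_l^{\mathrm{BP}} = \frac{\partial \ell}{\partial h_l}.
+Then compute
+    \Gamma_l = \mathbb E\left[\cos(a_l,\delta_l^{\mathrm{BP}})\right].
+Report it per layer as an auxiliary diagnostic. The metric is valuable but secondary, because our core requirement is training without a hidden BP anchor, not avoiding BP entirely offline.
+
+5. Bridge residual
+Report it only for the credit bridge, to judge whether the value recursion has been learned. The definition is the same as in phase A.
+
+6. Feature drift
+For the parameters of each block, compute
+    M_l = \frac{\|W_l^{\mathrm{final}}-W_l^{\mathrm{init}}\|_F}{\|W_l^{\mathrm{init}}\|_F}.
+This helps show whether the different methods all remain in an approximately lazy regime. Report it per layer.
+
+7. State bridge terminal prediction quality
+Additionally report the state bridge's \|G_\psi(h_l,t_l,s)-h_L\|_2^2 or its average relative error. What we want to show is not "it cannot learn anything" but "even if it predicts the terminal state reasonably well, it does not necessarily produce good credit".
+
+⸻
+
+VI. Implementation advice and engineering details
+
+Use PyTorch 2.x. For the toy and small models, prefer float32 and do not start with mixed precision, because the credit bridge's second-order graph is more prone to instability under mixed precision.
+
+For the value network V_\phi, start with a small shared MLP: the input is the concatenation of [\mathrm{LN}(h_l), \mathrm{time\_embed}(t_l), s], the hidden width is 256 or 512, with 2 to 3 layers and a scalar output. The time embedding can be a small MLP applied to the scalar t_l or a sinusoidal positional encoding. What matters is stability, not sophistication.
+
+For the state predictor G_\psi, also use a shared MLP, but with output dimension d. Do not make it too large, or it will overshadow the forward network.
+
+Try \lambda = 0.1 and 1.0 first. The bridge noise \sigma_l can start as a constant between 0.01 and 0.05. The recommended EMA momentum for the target network is 0.99 or 0.995. Use K=4 Monte Carlo samples per \hat V_l^{\mathrm{tgt}} at first; the toy can use 8.
+
+Initialize the DFA random feedback matrices B_l as Gaussian and scale them to a suitable magnitude, so that the credit magnitude neither explodes nor vanishes. Apply a simple normalization so that the RMS of the DFA credit is roughly the same order at every layer. For the credit bridge, also run a normalized variant of a_l before it enters the local surrogate:
+    \tilde a_l = \frac{a_l}{\mathrm{RMS}(a_l)+10^{-6}}.
+You may keep both the raw and the normalized local-update variants, but compare them at least in pilot experiments; otherwise pure numerical-scale effects may mislead you.
+
+Pay special attention to one point: if you train the feedback/value network directly on h_l from the forward graph, autograd may propagate the gradient of \mathcal L_\phi back into the forward network, which breaks the "no hidden BP anchor" setting. Training the feedback/value network must therefore use detached hidden copies, and updating block parameters must also use a detached hidden input. Write this into the code comments and state it explicitly in the README.
+
+⸻
+
+VII. Recommended code organization
+
+Structure the code so that experiments are clear and reproducible. At a minimum, include these files:
+models/residual_mlp.py: defines the forward residual MLP.
+models/value_net.py: defines V_\phi.
+models/state_bridge.py: defines G_\psi.
+methods/bp.py: BP baseline training logic.
+methods/dfa.py: DFA baseline training logic.
+methods/state_bridge_train.py: state bridge training and local-update logic.
+methods/credit_bridge_train.py: credit bridge training and local-update logic.
+experiments/toy_lq.py: linear-quadratic sanity check.
+experiments/cifar_resmlp.py: main experiment.
+metrics/credit_metrics.py: implements \rho_l, nudging, BP cosine, bridge residual.
+configs/*.yaml: all experiment configs.
+README_experiments.md: how to run.
+report/: final figures and report.
+
+Every experiment must save its config, seed, git commit hash, training log, and intermediate checkpoints. Do not run experiments that were "hand-tuned a lot but never recorded".
+
+⸻
+
+VIII. Priorities and expected outputs
+
+Execute in the following priority order; do not reorder.
+
+First priority: complete the toy linear-quadratic sanity check and obtain clear figures. At a minimum, provide per-layer curves or tables of \Gamma_l, \rho_l, \Delta_l^{\mathrm{nudge}}, and bridge residual, comparing DFA, state bridge, and credit bridge. If at this stage the credit bridge looks more like the true local gradient than the state bridge, that is already a positive signal.
+
+Second priority: complete the main experiment with the 4 methods (BP, DFA, state bridge, credit bridge) on the CIFAR-10 deep residual MLP. At least 3 seeds are required. Output accuracy curves, per-layer \rho_l curves, the nudging test, offline BP cosine, and feature drift. The most important comparison is not "who has the highest accuracy" but "who produces a local signal with better credit semantics when there is no hidden BP anchor".
+
+Third priority: if the first two steps go well, add credit bridge + FM auxiliary. Trying it on the toy and on a smaller main experiment is enough; a full-scale run is not needed at first, because this step requires second-order autograd and has a higher engineering cost.
+
+⸻
+
+IX. Success criteria and how to decide whether to continue
+
+The success criterion for this project is not "the credit bridge ends up more accurate than BP"; that is unrealistic. More reasonable milestone criteria are:
+
+First, on the toy linear system, the credit bridge's exact costate cosine and local perturbation correlation are clearly higher than the state bridge's and generally higher than DFA's. As long as \rho_l>0 and \Delta_l^{\mathrm{nudge}}<0 hold on most layers and the credit bridge beats the state bridge, this line of work has value.
+
+Second, in the main experiment, even if final accuracy does not yet clearly exceed DFA, the credit bridge being consistently better than DFA on \rho_l, nudging, and offline BP cosine in the early and middle layers shows that it is indeed closer to the correct credit-assignment object.
+
+Third, if the state bridge's terminal-state regression is reasonably good but its \rho_l and nudging remain clearly worse than the credit bridge's, that strongly supports our core claim: bridging state is not enough, and bridging the credit / value field is the better direction.
+
+⸻
+
+X. What you must deliver in the end
+
+Deliver at least the following:
+1. A concise but complete README_experiments.md explaining how to run the toy and the main experiment.
+2. Reproducible experiment code with clear configs, where method, dataset, and seed can be specified from the command line.
+3. A results report, preferably PDF or markdown, containing: method definitions, implementation notes, key figures, main findings, failure points, and suggested next steps.
+4. At least the following figures:
+• toy: per-layer exact costate cosine, \rho_l, nudging, bridge residual;
+• main experiment: train/test accuracy, per-layer \rho_l, nudging, offline BP cosine, feature drift;
+• a plot contrasting the state bridge's terminal prediction error with its credit quality.
+5. A final conclusion paragraph that explicitly answers three questions:
+• Without a hidden BP anchor, can the credit bridge learn useful credit?
+• Does it fit the definition of a "credit-assignment object" better than the state bridge?
+• Does it beat DFA on the main diagnostic metrics?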
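The phase A diagnostics (exact costate recursion, perturbation correlation \rho_l, and the nudging test) can be sketched in pure Python on a tiny linear-quadratic system. This is only an illustration under stated assumptions: the tail rollout is noise-free (the spec does not say how the noise \xi_l is handled when re-rolling), the function names are hypothetical, and the 2-dimensional example values are made up.

```python
import math
import random

def matvec(A, x):
    return [sum(a * b for a, b in zip(row, x)) for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def rollout(h, maps):
    """Noise-free tail rollout h -> h_L (assumption: noise omitted at evaluation)."""
    for M in maps:
        h = matvec(M, h)
    return h

def terminal_loss(h_L, C, y):
    resid = [p - t for p, t in zip(matvec(C, h_L), y)]
    return 0.5 * sum(r * r for r in resid)

def exact_costates(h0, maps, C, y):
    """a_L = C^T (C h_L - y), then a_l = M_l^T a_{l+1}; returns [a_0, ..., a_L]."""
    resid = [p - t for p, t in zip(matvec(C, rollout(h0, maps)), y)]
    a = matvec(transpose(C), resid)
    costates = [a]
    for M in reversed(maps):
        a = matvec(transpose(M), a)
        costates.append(a)
    return costates[::-1]

def pearson(xs, ys):
    mx, my = sum(xs) / len(xs), sum(ys) / len(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    return cov / math.sqrt(vx * vy)

def perturbation_correlation(a_l, h_l, tail_maps, C, y, eps=1e-3, num_dirs=32, seed=0):
    """rho_l: correlation between <a_l, eps v> and the true change in terminal loss."""
    rng = random.Random(seed)
    base = terminal_loss(rollout(h_l, tail_maps), C, y)
    pred, true = [], []
    for _ in range(num_dirs):
        v = [rng.gauss(0.0, 1.0) for _ in h_l]
        norm = math.sqrt(sum(x * x for x in v))
        v = [x / norm for x in v]
        pred.append(eps * sum(a * x for a, x in zip(a_l, v)))
        h_pert = [h + eps * x for h, x in zip(h_l, v)]
        true.append(terminal_loss(rollout(h_pert, tail_maps), C, y) - base)
    return pearson(pred, true)

def nudge_delta(a_l, h_l, tail_maps, C, y, eta=1e-2):
    """Delta^nudge: change in terminal loss after an RMS-normalized step along -a_l."""
    rms = math.sqrt(sum(a * a for a in a_l) / len(a_l))
    h_nudged = [h - eta * a / (rms + 1e-6) for h, a in zip(h_l, a_l)]
    return (terminal_loss(rollout(h_nudged, tail_maps), C, y)
            - terminal_loss(rollout(h_l, tail_maps), C, y))

# Hypothetical 2-dimensional system with a single layer.
maps = [[[1.0, 0.1], [0.0, 1.0]]]
C = [[1.0, 0.0]]
h0, y = [1.0, 1.0], [0.0]
costates = exact_costates(h0, maps, C, y)
print("a_0:", costates[0])
print("rho_0:", perturbation_correlation(costates[0], h0, maps, C, y))
print("nudge:", nudge_delta(costates[0], h0, maps, C, y))
```

Feeding the exact costate into both diagnostics gives the reference behaviour (high \rho_l, negative \Delta^{\mathrm{nudge}}) that a learned credit from DFA, the state bridge, or the credit bridge would be compared against.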
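+⸻
+
+XI. Things not to do
+
+Do not use offline BP cosine as a training objective.
+Do not sneak BP supervision into the hidden layers.
+Do not break locality to chase accuracy.
+Do not switch on every complex regularizer from the start.
+Do not report accuracy without the credit diagnostics.
+Do not neglect random seeds and experiment records.
+Do not skip the toy sanity check because the main experiment is hard.
+
+⸻
+
+XII. How to degrade gracefully if you get blocked
+
+If the credit bridge is unstable at first, degrade in the following order instead of changing things at random:
+First use only \mathcal L_{\mathrm{term}}+\mathcal L_{\mathrm{bridge}}, without the FM auxiliary.
+If it is still unstable, reduce the model width, reduce \sigma, and reduce the learning rate.
+If the credit magnitude is very unstable, use the normalized credit in the local surrogate first.
+If the main experiment is too slow, run a smoke test on FashionMNIST first, but return to CIFAR-10 in the end.
+If the main experiment is very hard to get working, prioritize completing the toy experiment and the diagnostic metrics, because they are the key evidence for deciding whether this direction is worth pursuing.
+
+⸻
+
+Once you start, first complete the following two minimal goals before extending further:
+The first minimal goal: on the toy linear system, exact costate cosine and \rho_l curves for credit bridge vs state bridge vs DFA.
+The second minimal goal: on the CIFAR-10 deep residual MLP, a comparison of \rho_l and nudging for BP vs DFA vs state bridge vs credit bridge.
+
+If the results are very poor after these minimal goals, do not hide the failure. Report it honestly and try to distinguish the two kinds of failure: "the theoretical direction has no signal" and "the engineering implementation is not yet stable". Do not stop before all results have been obtained; you may re-read CLAUDE.md or search the web for answers. Record all results, everything you tried, and every change in NOTE.md. Use nvidia-smi to check GPU availability. Killing other users' processes is strictly forbidden, and until the work is finished you must not hand control back to me or ask me questions. I am about to go to sleep and will not answer questions, so keep looping until all results are in.
+
@@ -0,0 +1,22 @@
+# Experiment Notes
+
+## 2026-03-23: Initial Implementation and Experiments
+
+### Setup
+- GPU: NVIDIA RTX A6000 x4 (using GPU 1)
+- PyTorch 2.10.0+cu128
+- All code written from scratch following CLAUDE.md specifications
+
+### Phase A: Toy LQ Sanity Check
+- Status: Running...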
+- Config: d=64, m=10, L=12, sigma=0.03, 5000 steps, batch=256
+- Methods: DFA, State Bridge, Credit Bridge
+
+### Changes Log
+- Created full project structure: models/, methods/, experiments/, metrics/, configs/
+- models/residual_mlp.py: ResidualMLP with pre-LayerNorm residual blocks
+- models/value_net.py: ValueNet V_phi with sinusoidal time embedding
+- models/state_bridge.py: StateBridgeNet G_psi
+- experiments/toy_lq.py: Linear-quadratic sanity check
+- experiments/cifar_resmlp.py: CIFAR-10 main experiment
+- metrics/credit_metrics.py: All diagnostic metrics
diff --git a/configs/cifar10.yaml b/configs/cifar10.yaml
new file mode 100644
index 0000000..6429287
--- /dev/null
+++ b/configs/cifar10.yaml
@@ -0,0 +1,15 @@
+dataset: cifar10
+d_hidden: 512
+num_blocks: 12
+batch_size: 128
+epochs: 100
+lr: 0.001
+lr_fb: 0.001
+wd: 0.01
+lam: 0.1
+K: 4
+sigma_bridge: 0.03
+ema_momentum: 0.995
+seeds: [42, 123, 456]
+gpu: 1
+output_dir: results/cifar10
diff --git a/configs/toy_lq.yaml b/configs/toy_lq.yaml
new file mode 100644
index 0000000..ad1bba0
--- /dev/null
+++ b/configs/toy_lq.yaml
@@ -0,0 +1,14 @@
+d_hidden: 64
+output_dim: 10
+num_layers: 12
+sigma: 0.03
+batch_size: 256
+num_steps: 5000
+lr_fb: 0.001
+lam: 0.1
+K: 8
+ema_momentum: 0.995
+sigma_bridge: 0.03
+eval_every: 200
+gpu: 1
+output_dir: results/toy_lq
diff --git a/experiments/__init__.py b/experiments/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/experiments/__init__.py
diff --git a/experiments/__pycache__/__init__.cpython-313.pyc b/experiments/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000..5966841
--- /dev/null
+++ b/experiments/__pycache__/__init__.cpython-313.pyc
Binary files differ
diff --git a/experiments/__pycache__/toy_lq.cpython-313.pyc b/experiments/__pycache__/toy_lq.cpython-313.pyc
new file mode 100644
index 0000000..d8710a8
--- /dev/null
+++ b/experiments/__pycache__/toy_lq.cpython-313.pyc
Binary files differ
diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py
new file mode 100644
index 0000000..1582f6d
--- /dev/null
+++ b/experiments/cifar_resmlp.py
@@ -0,0 +1,775 @@
+"""
+Phase B: Deep Residual MLP on CIFAR-10.
+Compare BP, DFA, State Bridge, Credit Bridge.
+
+CRITICAL CONSTRAINT: No hidden BP anchor for non-BP methods.
+All block updates use detached hidden states and local surrogates.
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import copy
+import time
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import ValueNet, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+from metrics.credit_metrics import (
+    cosine_similarity_batch, perturbation_correlation, nudging_test,
+    offline_bp_cosine, feature_drift
+)
+
+
+def get_data(dataset='cifar10', batch_size=128):
+    if dataset == 'cifar10':
+        transform_train = transforms.Compose([
+            transforms.RandomCrop(32, padding=4),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+        ])
+        transform_test = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+        ])
+        trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+        testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+        input_dim = 32 * 32 * 3
+        num_classes = 10
+    elif dataset == 'fashionmnist':
+        transform_train = transforms.Compose([
+            transforms.RandomCrop(28, padding=2),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize((0.2860,), (0.3530,)),
+        ])
+        transform_test = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize((0.2860,), (0.3530,)),
+        ])
+        trainset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform_train)
+        testset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform_test)
+        input_dim = 28 * 28
+        num_classes = 10
+    else:
+        raise ValueError(f"Unknown dataset: {dataset}")
+
+    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+    return train_loader, test_loader, input_dim, num_classes
+
+
+def evaluate(model, test_loader, device):
+    model.eval()
+    correct, total = 0, 0
+    with torch.no_grad():
+        for x, y in test_loader:
+            x = x.view(x.size(0), -1).to(device)
+            y = y.to(device)
+            logits = model(x)
+            correct += (logits.argmax(1) == y).sum().item()
+            total += x.size(0)
+    return correct / total
+
+
+# =============================================================================
+# BP Baseline
+# =============================================================================
+def train_bp(model, train_loader, test_loader, device, args):
+    """Standard end-to-end backprop training."""
+    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
+    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
+
+    log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+
+    for epoch in range(1, args.epochs + 1):
+        model.train()
+        total_loss, correct, total = 0, 0, 0
+        for x, y in train_loader:
+            x = x.view(x.size(0), -1).to(device)
+            y = y.to(device)
+            logits = model(x)
+            loss = F.cross_entropy(logits, y)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            total_loss += loss.item() * x.size(0)
+            correct += (logits.argmax(1) == y).sum().item()
+            total += x.size(0)
+
+        scheduler.step()
+        train_loss = total_loss / total
+        train_acc = correct / total
+        test_acc = evaluate(model, test_loader, device)
+        log['train_loss'].append(train_loss)
+        log['train_acc'].append(train_acc)
+        log['test_acc'].append(test_acc)
+        if epoch % 10 == 0 or epoch == 1:
+            print(f" [BP] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, test={test_acc:.4f}")
+
+    return log
+
+
+# =============================================================================
+# DFA Baseline
+# =============================================================================
+def train_dfa(model, train_loader, test_loader, device, args):
+    """
+    DFA training with fixed random feedback matrices.
+    Each block updated with local surrogate: L_l = <F_l(h_l), sg[a_{l+1}^DFA]>.
+    Output head updated with exact CE gradient (h_L detached).
+    Embedding updated via DFA credit at h_0.
+    """
+    d = model.d_hidden
+    num_classes = args.num_classes
+    L = model.num_blocks
+
+    # Fixed random feedback matrices, one per block
+    Bs = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)]
+
+    # Separate optimizers
+    block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+                  for block in model.blocks]
+    embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+    head_opt = optim.AdamW(
+        list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+        lr=args.lr, weight_decay=args.wd
+    )
+
+    all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] +
+                      [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs),
+                       optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+    log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+
+    for epoch in range(1, args.epochs + 1):
+        model.train()
+        total_loss, correct, total = 0, 0, 0
+
+        for x, y in train_loader:
+            x = x.view(x.size(0), -1).to(device)
+            y = y.to(device)
+            batch = x.size(0)
+
+            # Forward pass (no grad for hidden states)
+            with torch.no_grad():
+                logits, hiddens = model(x, return_hidden=True)
+                loss_val = F.cross_entropy(logits, y)
+                # e_T = softmax(logits) - one_hot(y)
+                e_T = logits.softmax(dim=-1)
+                e_T[torch.arange(batch), y] -= 1  # (batch, num_classes)
+
+            # 1. Update output head: exact CE gradient, h_L detached
+            hL_det = hiddens[-1].detach()
+            logits_out = model.out_head(model.out_ln(hL_det))
+            loss_out = F.cross_entropy(logits_out, y)
+            head_opt.zero_grad()
+            loss_out.backward()
+            head_opt.step()
+
+            # 2. Update each block with DFA local surrogate
+            for l in range(L):
+                h_l = hiddens[l].detach()
+                # DFA credit: a_{l+1} = B_l @ e_T^T -> (d, batch) -> transpose
+                a_dfa = (e_T @ Bs[l].T).detach()  # (batch, d) = (batch, C) @ (C, d)
+                # Normalize
+                rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+                a_dfa_norm = a_dfa / rms
+                # Local surrogate
+                f_l = model.blocks[l](h_l)
+                local_loss = (f_l * a_dfa_norm).sum(dim=-1).mean()
+                block_opts[l].zero_grad()
+                local_loss.backward()
+                block_opts[l].step()
+
+            # 3. Update embedding with DFA credit at h_0
+            a_0_dfa = (e_T @ Bs[0].T).detach()
+            rms_0 = (a_0_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+            a_0_norm = a_0_dfa / rms_0
+            h0 = model.embed(x)
+            embed_loss = (h0 * a_0_norm).sum(dim=-1).mean()
+            embed_opt.zero_grad()
+            embed_loss.backward()
+            embed_opt.step()
+
+            total_loss += loss_val.item() * batch
+            correct += (logits.argmax(1) == y).sum().item()
+            total += batch
+
+        for s in all_schedulers:
+            s.step()
+
+        train_loss = total_loss / total
+        train_acc = correct / total
+        test_acc = evaluate(model, test_loader, device)
+        log['train_loss'].append(train_loss)
+        log['train_acc'].append(train_acc)
+        log['test_acc'].append(test_acc)
+        if epoch % 10 == 0 or epoch == 1:
+            print(f" [DFA] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, test={test_acc:.4f}")
+
+    return log, Bs
+
+
+# =============================================================================
+# State Bridge
+# =============================================================================
+def train_state_bridge(model, train_loader, test_loader, device, args):
+    """
+    State Bridge: predict terminal h_L from (h_l, t_l, s), derive credit as
+    a_l = grad_{h_l} CE(W_out * LN(G_psi(h_l, t_l, s)), y).
+    """
+    d = model.d_hidden
+    num_classes = args.num_classes
+    L = model.num_blocks
+
+    state_pred = StateBridgeNet(
+        d_hidden=d, s_dim=num_classes, time_embed_dim=32, hidden_dim=256, num_layers=3
+    ).to(device)
+
+    block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+                  for block in model.blocks]
+    embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+    head_opt = optim.AdamW(
+        list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+        lr=args.lr, weight_decay=args.wd
+    )
+    state_opt = optim.Adam(state_pred.parameters(), lr=args.lr_fb)
+
+    all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] +
+                      [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs),
+                       optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+    log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'state_pred_error': []}
+
+    for epoch in range(1, args.epochs + 1):
+        model.train()
+        state_pred.train()
+        total_loss, correct, total = 0, 0, 0
+        total_se = 0
+
+        for x, y in train_loader:
+            x = x.view(x.size(0), -1).to(device)
+            y = y.to(device)
+            batch = x.size(0)
+
+            with torch.no_grad():
+                logits, hiddens = model(x, return_hidden=True)
+                loss_val = F.cross_entropy(logits, y)
+                e_T = logits.softmax(dim=-1)
+                e_T[torch.arange(batch), y] -= 1
+                s = e_T.detach()
+
+            hL_det = hiddens[-1].detach()
+
+            # Train state predictor: G_psi(h_l, t_l, s) -> h_L
+            # Predict the *residual* from h_l to h_L for numerical stability
+            state_loss = 0.0
+            for l in range(L):
+                h_l_det = hiddens[l].detach()
+                t_l = torch.full((batch,), l / L, device=device)
+                pred_hL = state_pred(h_l_det, t_l, s)
+                # Target: h_L (use normalized MSE for stability)
+                target = hL_det
+                target_norm = target.norm(dim=-1, keepdim=True).clamp(min=1.0)
+                state_loss = state_loss + (((pred_hL - target) / target_norm) ** 2).sum(dim=-1).mean()
+            state_loss = state_loss / L
+            state_opt.zero_grad()
+            state_loss.backward()
+            state_opt.step()
+            total_se += state_loss.item() * batch
+
+            # Compute credits: a_l = grad_{h_l} CE(out_head(LN(G(h_l, t_l, s))), y)
+            credits = []
+            for l in range(L):
+                h_l_det = hiddens[l].detach().requires_grad_(True)
+                t_l = torch.full((batch,), l / L, device=device)
+                pred_hL = state_pred(h_l_det, t_l, s)
+                pred_logits = model.out_head(model.out_ln(pred_hL))
+                pred_loss = F.cross_entropy(pred_logits, y, reduction='sum')
+                a_l = torch.autograd.grad(pred_loss, h_l_det, create_graph=False)[0]
+                credits.append(a_l.detach())
+
+            # Update output head
+            logits_out = model.out_head(model.out_ln(hL_det))
+            loss_out = F.cross_entropy(logits_out, y)
+            head_opt.zero_grad()
+            loss_out.backward()
+            head_opt.step()
+
+            # Update blocks
+            for l in range(L):
+                h_l = hiddens[l].detach()
+                a = credits[l]
+                rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+                a_norm = a / rms
+                f_l = model.blocks[l](h_l)
+                local_loss = (f_l * a_norm).sum(dim=-1).mean()
+                block_opts[l].zero_grad()
+                local_loss.backward()
+                block_opts[l].step()
+
+            # Update embedding with credit at layer 0
+            a_0 = credits[0]
+            rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+            a_0_norm = a_0 / rms_0
+            h0 = model.embed(x)
+            embed_loss = (h0 * a_0_norm).sum(dim=-1).mean()
+            embed_opt.zero_grad()
+            embed_loss.backward()
+            embed_opt.step()
+
+            total_loss += loss_val.item() * batch
+            correct += (logits.argmax(1) == y).sum().item()
+            total += batch
+
+        for sch in all_schedulers:
+            sch.step()
+
+        train_loss = total_loss / total
+        train_acc = correct / total
+        test_acc = evaluate(model, test_loader, device)
+        se = total_se / total
+        log['train_loss'].append(train_loss)
+        log['train_acc'].append(train_acc)
+        log['test_acc'].append(test_acc)
+        log['state_pred_error'].append(se)
+        if epoch % 10 == 0 or epoch == 1:
+            print(f" [SB] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, "
+                  f"test={test_acc:.4f}, state_err={se:.4f}")
+
+    return log, state_pred
+
+
+# =============================================================================
+# Credit Bridge
+# =============================================================================
+def train_credit_bridge(model, train_loader, test_loader, device, args):
+    """
+    Credit Bridge: learn V_phi(h_l, t_l, s) -> scalar value.
+    Credit: a_l = grad_{h_l} V_phi.
+    Training: terminal boundary + bridge consistency + terminal gradient matching.
+    The terminal gradient is local (output layer only), NOT hidden BP.
+
+    Uses a warmup phase: first warmup_epochs, only train value net + output head,
+    then start using credit bridge signals to update blocks.
+    During warmup, blocks get DFA-style updates as a fallback.
+    """
+    d = model.d_hidden
+    num_classes = args.num_classes
+    L = model.num_blocks
+    warmup_epochs = max(1, args.epochs // 5)  # 20% warmup
+
+    value_net = ValueNet(
+        d_hidden=d, s_dim=num_classes, time_embed_dim=32, hidden_dim=256, num_layers=3
+    ).to(device)
+    value_net_ema = create_ema_model(value_net)
+
+    # DFA fallback matrices for warmup
+    Bs_fallback = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes)
+                   for _ in range(L)]
+
+    block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+                  for block in model.blocks]
+    embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+    head_opt = optim.AdamW(
+        list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+        lr=args.lr, weight_decay=args.wd
+    )
+    value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb)
+
+    all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] +
+                      [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs),
+                       optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+    lam = args.lam
+    K_samples = args.K
+    sigma_bridge = args.sigma_bridge
+    ema_momentum = args.ema_momentum
+    term_grad_weight = args.term_grad_weight
+
+    log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'value_loss': []}
+
+    print(f" [CB] Warmup phase: {warmup_epochs} epochs (DFA fallback + value net training)")
+
+    for epoch in range(1, args.epochs + 1):
+        model.train()
+        value_net.train()
+        total_loss, correct, total = 0, 0, 0
+        total_vloss = 0
+
+        # Blend factor: 0 during warmup, linearly increases to 1 after warmup
+        if epoch <= warmup_epochs:
+            credit_blend = 0.0
+        else:
+            credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs))
+
+        for x, y in train_loader:
+            x = x.view(x.size(0), -1).to(device)
+            y = y.to(device)
+            batch = x.size(0)
+
+            with torch.no_grad():
+                logits, hiddens = model(x, return_hidden=True)
+                loss_val = F.cross_entropy(logits, y)
+                e_T = logits.softmax(dim=-1)
+                e_T[torch.arange(batch), y] -= 1
+                s = e_T.detach()
+                true_loss = F.cross_entropy(logits, y, reduction='none').detach()
+
+            hL_det = hiddens[-1].detach()
+
+            # ---- Train value net (always) ----
+            t_L = torch.ones(batch, device=device)
+            V_terminal = value_net(hL_det, t_L, s)
+            loss_term = ((V_terminal - true_loss) ** 2).mean()
+
+            # Terminal gradient matching
+            loss_tgrad = torch.tensor(0.0, device=device)
+            if term_grad_weight > 0:
+                hL_req = hL_det.clone().requires_grad_(True)
+                V_at_L = value_net(hL_req, t_L, s)
+                grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
+                hL_req2 = hL_det.clone().requires_grad_(True)
+                logits_tgt = model.out_head(model.out_ln(hL_req2))
+                ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum')
+                a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach()
+                loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()
+
+            # Bridge consistency
+            loss_bridge = 0.0
+            for l in range(L):
+                h_l_det = hiddens[l].detach()
+                t_l = torch.full((batch,), l / L, device=device)
+                t_l_next = torch.full((batch,), (l + 1) / L, device=device)
+                V_l = value_net(h_l_det, t_l, s)
+
+                with torch.no_grad():
+                    h_next_det = hiddens[l + 1].detach()
+                    log_terms = []
+                    for k in range(K_samples):
+                        noise = sigma_bridge * torch.randn_like(h_next_det)
+                        V_next = value_net_ema(h_next_det + noise, t_l_next, s)
+                        log_terms.append(-V_next / lam)
+                    log_stack = torch.stack(log_terms, dim=-1)
+                    V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K_samples))
+
+                loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
+            loss_bridge = loss_bridge / L
+
+            value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad
+
+            value_opt.zero_grad()
+            value_loss.backward()
+            torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
+            value_opt.step()
+            update_ema(value_net, value_net_ema, ema_momentum)
+            total_vloss += value_loss.item() * batch
+
+            # ---- Compute credits ----
+            # Credit bridge credits
+            cb_credits = []
+            for l in range(L):
+                h_l_det = hiddens[l].detach().requires_grad_(True)
+                t_l = torch.full((batch,), l / L, device=device)
+                V_l = value_net(h_l_det, t_l, s)
+                a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0]
+                cb_credits.append(a_l.detach())
+
+            # DFA fallback credits
+            dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)]
+
+            # Blend credits
+            credits = []
+            for l in range(L):
+                if credit_blend >= 1.0:
+                    a = cb_credits[l]
+                elif credit_blend <= 0.0:
+                    a = dfa_credits[l]
+                else:
+                    # Normalize both before blending
+                    cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+                    dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+                    a = credit_blend * (cb_credits[l] / cb_rms) + (1 - credit_blend) * (dfa_credits[l] / dfa_rms)
+                credits.append(a)
+
+            # ---- Update output head ----
+            logits_out = model.out_head(model.out_ln(hL_det))
+            loss_out = F.cross_entropy(logits_out, y)
+            head_opt.zero_grad()
+            loss_out.backward()
+            head_opt.step()
+
+            # ---- Update blocks ----
+            for l in range(L):
+                h_l = hiddens[l].detach()
+                a = credits[l]
+                rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+                a_norm = a / rms
+                f_l = model.blocks[l](h_l)
+                local_loss = (f_l * a_norm).sum(dim=-1).mean()
+                block_opts[l].zero_grad()
+                local_loss.backward()
+                block_opts[l].step()
+
+            # ---- Update embedding ----
+            a_0 = credits[0]
+            rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+            a_0_norm = a_0 / rms_0
+            h0 = model.embed(x)
+            embed_loss = (h0 * a_0_norm).sum(dim=-1).mean()
+            embed_opt.zero_grad()
+            embed_loss.backward()
+            embed_opt.step()
+
+            total_loss += loss_val.item() * batch
+            correct += (logits.argmax(1) == y).sum().item()
+            total += batch
+
+        for sch in all_schedulers:
+            sch.step()
+
+        train_loss = total_loss / total
+        train_acc = correct / total
+        test_acc = evaluate(model, test_loader, device)
+        vloss = total_vloss / total
+        log['train_loss'].append(train_loss)
+        log['train_acc'].append(train_acc)
+        log['test_acc'].append(test_acc)
+        log['value_loss'].append(vloss)
+        if epoch % 10 == 0 or epoch == 1:
+            phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}"
+            print(f" [CB] Epoch {epoch} ({phase}): loss={train_loss:.4f}, train={train_acc:.4f}, "
+                  f"test={test_acc:.4f}, vloss={vloss:.6f}")
+
+    return log, value_net, value_net_ema
+
+
+# =============================================================================
+# Diagnostics
+# =============================================================================
+def compute_diagnostics(model, method_name, test_loader, device, args,
+                        value_net=None, state_predictor=None, dfa_Bs=None):
+    """Compute all diagnostic metrics for a trained model."""
+    model.eval()
+    if value_net is not None:
+        value_net.eval()
+    if state_predictor is not None:
+        state_predictor.eval()
+
+    d = model.d_hidden
+    L = model.num_blocks
+    num_classes = args.num_classes
+
+    # Get one batch for diagnostics
+    for x, y in test_loader:
+        x = x.view(x.size(0), -1).to(device)
+        y = y.to(device)
+        break
+
+    batch = x.size(0)
+
+    # Forward with hidden states, need grad for BP cosine
+    logits_bp, hiddens_bp = model(x, return_hidden=True)
+    for l in range(L + 1):
+        hiddens_bp[l].retain_grad()
+ loss_bp = F.cross_entropy(logits_bp, y) + loss_bp.backward() + bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)} + + # Forward again without grad for clean hidden states + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + + results = { + 'bp_cosine': [], + 'perturbation_rho': [], + 'nudging': {'0.001': [], '0.003': [], '0.01': []}, + } + + for l in range(L): + h_l = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + + # Get credit + if method_name == 'bp': + a_l = bp_grads[l] + elif method_name == 'dfa': + a_l = (e_T @ dfa_Bs[l].T).detach() + elif method_name == 'state_bridge': + h_l_req = h_l.clone().requires_grad_(True) + pred_hL = state_predictor(h_l_req, t_l, s) + pred_logits = model.out_head(model.out_ln(pred_hL)) + pred_loss = F.cross_entropy(pred_logits, y, reduction='sum') + a_l = torch.autograd.grad(pred_loss, h_l_req, create_graph=False)[0].detach() + elif method_name == 'credit_bridge': + h_l_req = h_l.clone().requires_grad_(True) + V_l = value_net(h_l_req, t_l, s) + a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach() + else: + raise ValueError(f"Unknown method: {method_name}") + + # BP cosine + bp_cos = cosine_similarity_batch(a_l, bp_grads[l]) + results['bp_cosine'].append(bp_cos) + + # Forward function for perturbation and nudging + def make_fwd_fn(start_l): + def fwd_fn(h): + with torch.no_grad(): + curr = h + for i in range(start_l, L): + curr = curr + model.blocks[i](curr) + out = model.out_head(model.out_ln(curr)) + return F.cross_entropy(out, y, reduction='none') + return fwd_fn + + fwd_fn = make_fwd_fn(l) + rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16) + results['perturbation_rho'].append(rho) + + for eta in [0.001, 0.003, 0.01]: + nud = nudging_test(h_l, a_l, fwd_fn, eta=eta) + results['nudging'][str(eta)].append(nud) + + return results + + +# 
============================================================================= +# Main +# ============================================================================= +def run_experiment(args): + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + os.makedirs(args.output_dir, exist_ok=True) + + all_results = {} + + for seed in args.seeds: + print(f"\n{'='*60}") + print(f"Seed {seed}") + print(f"{'='*60}") + + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + + train_loader, test_loader, input_dim, num_classes = get_data(args.dataset, args.batch_size) + args.num_classes = num_classes + + seed_results = {} + + # ---- BP ---- + print("\n--- BP ---") + model_bp = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) + init_bp = {n: p.clone().detach() for n, p in model_bp.named_parameters()} + bp_log = train_bp(model_bp, train_loader, test_loader, device, args) + bp_diag = compute_diagnostics(model_bp, 'bp', test_loader, device, args) + bp_drift = feature_drift(init_bp, {n: p.detach() for n, p in model_bp.named_parameters()}) + seed_results['bp'] = {'log': bp_log, 'diagnostics': bp_diag, 'drift': bp_drift} + print(f" Final test acc: {bp_log['test_acc'][-1]:.4f}") + + # ---- DFA ---- + print("\n--- DFA ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_dfa = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) + init_dfa = {n: p.clone().detach() for n, p in model_dfa.named_parameters()} + dfa_log, dfa_Bs = train_dfa(model_dfa, train_loader, test_loader, device, args) + dfa_diag = compute_diagnostics(model_dfa, 'dfa', test_loader, device, args, dfa_Bs=dfa_Bs) + dfa_drift = feature_drift(init_dfa, {n: p.detach() for n, p in model_dfa.named_parameters()}) + seed_results['dfa'] = {'log': dfa_log, 'diagnostics': dfa_diag, 'drift': dfa_drift} + print(f" Final test acc: 
{dfa_log['test_acc'][-1]:.4f}") + + # ---- State Bridge ---- + print("\n--- State Bridge ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_sb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) + init_sb = {n: p.clone().detach() for n, p in model_sb.named_parameters()} + sb_log, state_pred = train_state_bridge(model_sb, train_loader, test_loader, device, args) + sb_diag = compute_diagnostics(model_sb, 'state_bridge', test_loader, device, args, + state_predictor=state_pred) + sb_drift = feature_drift(init_sb, {n: p.detach() for n, p in model_sb.named_parameters()}) + seed_results['state_bridge'] = {'log': sb_log, 'diagnostics': sb_diag, 'drift': sb_drift} + print(f" Final test acc: {sb_log['test_acc'][-1]:.4f}") + + # ---- Credit Bridge ---- + print("\n--- Credit Bridge ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_cb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) + init_cb = {n: p.clone().detach() for n, p in model_cb.named_parameters()} + cb_log, vnet, vnet_ema = train_credit_bridge(model_cb, train_loader, test_loader, device, args) + cb_diag = compute_diagnostics(model_cb, 'credit_bridge', test_loader, device, args, + value_net=vnet) + cb_drift = feature_drift(init_cb, {n: p.detach() for n, p in model_cb.named_parameters()}) + seed_results['credit_bridge'] = {'log': cb_log, 'diagnostics': cb_diag, 'drift': cb_drift} + print(f" Final test acc: {cb_log['test_acc'][-1]:.4f}") + + all_results[seed] = seed_results + + # Save + def serialize(obj): + if isinstance(obj, dict): + return {str(k): serialize(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [serialize(v) for v in obj] + elif isinstance(obj, (np.floating, np.integer)): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, torch.Tensor): + return obj.cpu().numpy().tolist() + return obj + 
+ save_data = serialize(all_results) + save_data['config'] = serialize(vars(args)) + # File name matches the default path read by experiments/plot_results.py + out_path = os.path.join(args.output_dir, f'cifar_results_{args.dataset}.json') + with open(out_path, 'w') as f: + json.dump(save_data, f, indent=2) + print(f"\nAll results saved to {out_path}") + return all_results + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--dataset', type=str, default='cifar10') + parser.add_argument('--d_hidden', type=int, default=512) + parser.add_argument('--num_blocks', type=int, default=12) + parser.add_argument('--batch_size', type=int, default=128) + parser.add_argument('--epochs', type=int, default=100) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--lr_fb', type=float, default=1e-3) + parser.add_argument('--wd', type=float, default=0.01) + parser.add_argument('--lam', type=float, default=0.1) + parser.add_argument('--K', type=int, default=4) + parser.add_argument('--sigma_bridge', type=float, default=0.05) + parser.add_argument('--ema_momentum', type=float, default=0.995) + parser.add_argument('--term_grad_weight', type=float, default=1.0) + parser.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456]) + parser.add_argument('--gpu', type=int, default=1) + parser.add_argument('--output_dir', type=str, default='results/cifar10') + args = parser.parse_args() + run_experiment(args) + + +if __name__ == '__main__': + main() diff --git a/experiments/plot_results.py b/experiments/plot_results.py new file mode 100644 index 0000000..e3e2754 --- /dev/null +++ b/experiments/plot_results.py @@ -0,0 +1,327 @@ +"""Generate plots for toy LQ and CIFAR-10 experiments.""" +import os +import sys +import json +import argparse +import numpy as np + +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + + +def plot_toy_results(results_dir='results/toy_lq', output_dir='report'): + """Plot toy LQ experiment results.""" + os.makedirs(output_dir, exist_ok=True) + + # Collect results
across seeds + files = [f for f in os.listdir(results_dir) if f.startswith('toy_lq_seed') and f.endswith('.json')] + if not files: + print(f"No toy results found in {results_dir}") + return + + all_data = [] + for f in sorted(files): + with open(os.path.join(results_dir, f)) as fp: + all_data.append(json.load(fp)) + + # Per-layer plots use only the last file in sorted order (results are not averaged across seeds) + data = all_data[-1] + per_layer = data['final_per_layer'] + log_data = data['log'] + + num_layers = len(per_layer['dfa_costate_cos']) + layers = list(range(num_layers)) + + # 1. Per-layer costate cosine + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + ax.plot(layers, per_layer['dfa_costate_cos'], 'o-', label='DFA', color='blue') + ax.plot(layers, per_layer['state_costate_cos'], 's-', label='State Bridge', color='orange') + ax.plot(layers, per_layer['credit_costate_cos'], '^-', label='Credit Bridge', color='green') + ax.set_xlabel('Layer') + ax.set_ylabel('Cosine Similarity with Exact Costate') + ax.set_title('Exact Costate Cosine (Toy LQ)') + ax.legend() + ax.grid(True, alpha=0.3) + ax.set_ylim(-0.2, 1.05) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'toy_costate_cosine.png'), dpi=150) + plt.close(fig) + + # 2. 
Per-layer perturbation correlation + num_rho_layers = len(per_layer['dfa_rho']) + rho_layers = list(range(num_rho_layers)) + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + ax.plot(rho_layers, per_layer['dfa_rho'], 'o-', label='DFA', color='blue') + ax.plot(rho_layers, per_layer['state_rho'], 's-', label='State Bridge', color='orange') + ax.plot(rho_layers, per_layer['credit_rho'], '^-', label='Credit Bridge', color='green') + ax.set_xlabel('Layer') + ax.set_ylabel('Perturbation Correlation (rho)') + ax.set_title('Local Perturbation Correlation (Toy LQ)') + ax.legend() + ax.grid(True, alpha=0.3) + ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'toy_perturbation_rho.png'), dpi=150) + plt.close(fig) + + # 3. Per-layer nudging test + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + ax.plot(rho_layers, per_layer['dfa_nudge'], 'o-', label='DFA', color='blue') + ax.plot(rho_layers, per_layer['state_nudge'], 's-', label='State Bridge', color='orange') + ax.plot(rho_layers, per_layer['credit_nudge'], '^-', label='Credit Bridge', color='green') + ax.set_xlabel('Layer') + ax.set_ylabel('Nudge Delta (negative = good)') + ax.set_title('Nudging Test (Toy LQ)') + ax.legend() + ax.grid(True, alpha=0.3) + ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'toy_nudging.png'), dpi=150) + plt.close(fig) + + # 4. Bridge residual over training + if log_data['bridge_residual']: + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + ax.plot(log_data['steps'], log_data['bridge_residual'], '-', color='green') + ax.set_xlabel('Training Step') + ax.set_ylabel('Bridge Residual') + ax.set_title('Bridge Residual Over Training (Toy LQ)') + ax.grid(True, alpha=0.3) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'toy_bridge_residual.png'), dpi=150) + plt.close(fig) + + # 5. 
Training curves (costate cosine over time) + fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + for ax, key, title in zip(axes, + ['dfa_costate_cos', 'state_costate_cos', 'credit_costate_cos'], + ['DFA', 'State Bridge', 'Credit Bridge']): + ax.plot(log_data['steps'], log_data[key], '-') + ax.set_xlabel('Training Step') + ax.set_ylabel('Avg Costate Cosine') + ax.set_title(f'{title} - Costate Cosine Over Training') + ax.grid(True, alpha=0.3) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'toy_cosine_training.png'), dpi=150) + plt.close(fig) + + # 6. Per-layer bridge residual + if per_layer.get('bridge_residual'): + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + br_layers = list(range(len(per_layer['bridge_residual']))) + ax.plot(br_layers, per_layer['bridge_residual'], '^-', color='green') + ax.set_xlabel('Layer') + ax.set_ylabel('Bridge Residual') + ax.set_title('Per-Layer Bridge Residual (Toy LQ)') + ax.grid(True, alpha=0.3) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'toy_bridge_residual_per_layer.png'), dpi=150) + plt.close(fig) + + print(f"Toy LQ plots saved to {output_dir}/") + + +def plot_cifar_results(results_path='results/cifar10/cifar_results_cifar10.json', output_dir='report'): + """Plot CIFAR-10 experiment results.""" + os.makedirs(output_dir, exist_ok=True) + + if not os.path.exists(results_path): + print(f"No CIFAR results found at {results_path}") + return + + with open(results_path) as f: + data = json.load(f) + + config = data.pop('config', {}) + methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge'] + colors = {'bp': 'red', 'dfa': 'blue', 'state_bridge': 'orange', 'credit_bridge': 'green'} + labels = {'bp': 'BP', 'dfa': 'DFA', 'state_bridge': 'State Bridge', 'credit_bridge': 'Credit Bridge'} + + seeds = [k for k in data.keys() if k != 'config'] + + # 1. 
Accuracy curves (mean ± std across seeds) + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + for method in methods: + train_accs = [] + test_accs = [] + for seed in seeds: + if method in data[seed]: + log = data[seed][method]['log'] + train_accs.append(log['train_acc']) + test_accs.append(log['test_acc']) + + if train_accs: + train_arr = np.array(train_accs) + test_arr = np.array(test_accs) + epochs = np.arange(1, train_arr.shape[1] + 1) + + mean_train = train_arr.mean(0) + std_train = train_arr.std(0) + mean_test = test_arr.mean(0) + std_test = test_arr.std(0) + + axes[0].plot(epochs, mean_train, '-', color=colors[method], label=labels[method]) + axes[0].fill_between(epochs, mean_train - std_train, mean_train + std_train, + alpha=0.15, color=colors[method]) + axes[1].plot(epochs, mean_test, '-', color=colors[method], label=labels[method]) + axes[1].fill_between(epochs, mean_test - std_test, mean_test + std_test, + alpha=0.15, color=colors[method]) + + axes[0].set_xlabel('Epoch') + axes[0].set_ylabel('Train Accuracy') + axes[0].set_title('Train Accuracy') + axes[0].legend() + axes[0].grid(True, alpha=0.3) + axes[1].set_xlabel('Epoch') + axes[1].set_ylabel('Test Accuracy') + axes[1].set_title('Test Accuracy') + axes[1].legend() + axes[1].grid(True, alpha=0.3) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'cifar_accuracy.png'), dpi=150) + plt.close(fig) + + # 2. 
Per-layer diagnostics (from last seed) + last_seed = seeds[-1] + + # BP cosine per layer + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + for method in methods: + if method in data[last_seed] and 'diagnostics' in data[last_seed][method]: + diag = data[last_seed][method]['diagnostics'] + if 'bp_cosine' in diag: + layers = list(range(len(diag['bp_cosine']))) + ax.plot(layers, diag['bp_cosine'], 'o-', color=colors[method], label=labels[method]) + ax.set_xlabel('Layer') + ax.set_ylabel('Cosine with BP Gradient') + ax.set_title('Offline BP Cosine (CIFAR-10)') + ax.legend() + ax.grid(True, alpha=0.3) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'cifar_bp_cosine.png'), dpi=150) + plt.close(fig) + + # Perturbation rho per layer + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + for method in methods: + if method in data[last_seed] and 'diagnostics' in data[last_seed][method]: + diag = data[last_seed][method]['diagnostics'] + if 'perturbation_rho' in diag: + layers = list(range(len(diag['perturbation_rho']))) + ax.plot(layers, diag['perturbation_rho'], 'o-', color=colors[method], label=labels[method]) + ax.set_xlabel('Layer') + ax.set_ylabel('Perturbation Correlation (rho)') + ax.set_title('Local Perturbation Correlation (CIFAR-10)') + ax.legend() + ax.grid(True, alpha=0.3) + ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'cifar_perturbation_rho.png'), dpi=150) + plt.close(fig) + + # Nudging test per layer (eta=0.01) + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + for method in methods: + if method in data[last_seed] and 'diagnostics' in data[last_seed][method]: + diag = data[last_seed][method]['diagnostics'] + if 'nudging' in diag and '0.01' in diag['nudging']: + nud = diag['nudging']['0.01'] + layers = list(range(len(nud))) + ax.plot(layers, nud, 'o-', color=colors[method], label=labels[method]) + ax.set_xlabel('Layer') + ax.set_ylabel('Nudge Delta (negative = good)') + 
ax.set_title('Nudging Test eta=0.01 (CIFAR-10)') + ax.legend() + ax.grid(True, alpha=0.3) + ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'cifar_nudging.png'), dpi=150) + plt.close(fig) + + # Feature drift per layer + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + for method in methods: + if method in data[last_seed] and 'drift' in data[last_seed][method]: + drift = data[last_seed][method]['drift'] + # Extract per-block drift (only block weights); block count comes from the saved config, defaulting to 12 + block_drifts = [] + for l in range(int(config.get('num_blocks', 12))): + key = f'blocks.{l}.w1.weight' + if key in drift: + block_drifts.append(drift[key]) + if block_drifts: + ax.plot(range(len(block_drifts)), block_drifts, 'o-', color=colors[method], label=labels[method]) + ax.set_xlabel('Block') + ax.set_ylabel('Feature Drift (||W_final - W_init||/||W_init||)') + ax.set_title('Feature Drift (CIFAR-10)') + ax.legend() + ax.grid(True, alpha=0.3) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'cifar_feature_drift.png'), dpi=150) + plt.close(fig) + + print(f"CIFAR-10 plots saved to {output_dir}/") + + +def print_summary_table(results_path='results/cifar10/cifar_results_cifar10.json'): + """Print summary table of results.""" + if not os.path.exists(results_path): + print(f"No results at {results_path}") + return + + with open(results_path) as f: + data = json.load(f) + + config = data.pop('config', {}) + methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge'] + labels = {'bp': 'BP', 'dfa': 'DFA', 'state_bridge': 'State Bridge', 'credit_bridge': 'Credit Bridge'} + + seeds = [k for k in data.keys() if k != 'config'] + + print("\n" + "="*80) + print("SUMMARY TABLE") + print("="*80) + print(f"{'Method':<20} {'Test Acc':<15} {'Avg rho':<15} {'Avg Nudge(0.01)':<15} {'Avg BP Cos':<15}") + print("-"*80) + + for method in methods: + test_accs = [] + avg_rhos = [] + avg_nudges = [] + avg_bp_cos = [] + + for seed in seeds: + if method in data[seed]: + log = 
data[seed][method]['log'] + test_accs.append(log['test_acc'][-1]) + + if 'diagnostics' in data[seed][method]: + diag = data[seed][method]['diagnostics'] + if 'perturbation_rho' in diag: + avg_rhos.append(np.mean(diag['perturbation_rho'])) + if 'nudging' in diag and '0.01' in diag['nudging']: + avg_nudges.append(np.mean(diag['nudging']['0.01'])) + if 'bp_cosine' in diag: + avg_bp_cos.append(np.mean(diag['bp_cosine'])) + + ta = f"{np.mean(test_accs):.4f}±{np.std(test_accs):.4f}" if test_accs else "N/A" + rho = f"{np.mean(avg_rhos):.4f}" if avg_rhos else "N/A" + nud = f"{np.mean(avg_nudges):.4f}" if avg_nudges else "N/A" + bpc = f"{np.mean(avg_bp_cos):.4f}" if avg_bp_cos else "N/A" + + print(f"{labels[method]:<20} {ta:<15} {rho:<15} {nud:<15} {bpc:<15}") + + print("="*80) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--toy_dir', type=str, default='results/toy_lq') + parser.add_argument('--cifar_path', type=str, default='results/cifar10/cifar_results_cifar10.json') + parser.add_argument('--output_dir', type=str, default='report') + args = parser.parse_args() + + plot_toy_results(args.toy_dir, args.output_dir) + plot_cifar_results(args.cifar_path, args.output_dir) + print_summary_table(args.cifar_path) diff --git a/experiments/plot_toy_final.py b/experiments/plot_toy_final.py new file mode 100644 index 0000000..2f7c109 --- /dev/null +++ b/experiments/plot_toy_final.py @@ -0,0 +1,183 @@ +"""Generate final toy LQ experiment plots from v2 results across 3 seeds.""" +import os +import json +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +output_dir = 'report' +os.makedirs(output_dir, exist_ok=True) + +# Load all v2 results with term_grad_weight=1.0, fm=0.0 +seeds = [42, 123, 456] +all_data = [] +for seed in seeds: + path = f'results/toy_lq/toy_lq_v2_seed{seed}_lam0.1_sig0.1_tgw1.0_fm0.0.json' + if os.path.exists(path): + with open(path) as f: + all_data.append(json.load(f)) + +if 
not all_data: + print("No results found!") + raise SystemExit(1) + +# Also load v1 baseline (no term_grad) for comparison +v1_path = 'results/toy_lq/toy_lq_seed42.json' +v1_data = None +if os.path.exists(v1_path): + with open(v1_path) as f: + v1_data = json.load(f) + +# Aggregate final per-layer results across seeds +methods = ['dfa', 'state', 'credit'] +colors = {'dfa': '#2196F3', 'state': '#FF9800', 'credit': '#4CAF50'} +labels = {'dfa': 'DFA', 'state': 'State Bridge', 'credit': 'Credit Bridge'} + +# Per-layer diagnostics: costate cosine, perturbation correlation, nudging +fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + +for ax, metric, title, ylabel in zip( + axes, + ['costate_cos', 'rho', 'nudge'], + ['Exact Costate Cosine', 'Perturbation Correlation (ρ)', 'Nudging Test'], + ['Cosine Similarity', 'Pearson Correlation', 'Loss Change (negative=good)'] +): + for method in methods: + key = f'{method}_{metric}' + values_per_seed = [] + for data in all_data: + pl = data['final_per_layer'] + if key in pl: + values_per_seed.append(pl[key]) + + if values_per_seed: + arr = np.array(values_per_seed) + mean = arr.mean(axis=0) + std = arr.std(axis=0) + layers = np.arange(len(mean)) + ax.plot(layers, mean, 'o-', color=colors[method], label=labels[method], markersize=5) + ax.fill_between(layers, mean - std, mean + std, alpha=0.15, color=colors[method]) + + ax.set_xlabel('Layer', fontsize=12) + ax.set_ylabel(ylabel, fontsize=12) + ax.set_title(title, fontsize=13) + ax.legend(fontsize=11) + ax.grid(True, alpha=0.3) + if metric == 'costate_cos': + ax.set_ylim(-0.15, 1.05) + elif metric in ('rho', 'nudge'): + ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5) + +fig.suptitle(f'Toy LQ Sanity Check: Per-Layer Diagnostics ({len(all_data)} seeds)', fontsize=14, y=1.02) +fig.tight_layout() +fig.savefig(os.path.join(output_dir, 'toy_per_layer_diagnostics.png'), dpi=150, bbox_inches='tight') +plt.close(fig) +print("Saved toy_per_layer_diagnostics.png") + +# Training curves +fig, 
axes = plt.subplots(1, 3, figsize=(18, 5)) +metric_keys = [ + ('costate_cos', 'Avg Costate Cosine', 'Cosine Similarity'), + ('rho', 'Avg Perturbation ρ', 'Pearson Correlation'), + ('nudge', 'Avg Nudging', 'Loss Change'), +] + +for ax, (metric, title, ylabel) in zip(axes, metric_keys): + for method in methods: + key = f'{method}_{metric}' + all_curves = [] + for data in all_data: + log = data['log'] + if key in log: + all_curves.append(np.array(log[key])) + + if all_curves: + # All should have same length, use shortest + min_len = min(len(c) for c in all_curves) + arr = np.array([c[:min_len] for c in all_curves]) + steps = np.array(all_data[0]['log']['steps'][:min_len]) + mean = arr.mean(axis=0) + std = arr.std(axis=0) + ax.plot(steps, mean, '-', color=colors[method], label=labels[method]) + ax.fill_between(steps, mean - std, mean + std, alpha=0.15, color=colors[method]) + + ax.set_xlabel('Training Step', fontsize=12) + ax.set_ylabel(ylabel, fontsize=12) + ax.set_title(title, fontsize=13) + ax.legend(fontsize=11) + ax.grid(True, alpha=0.3) + +fig.suptitle('Toy LQ: Training Curves (3 seeds)', fontsize=14, y=1.02) +fig.tight_layout() +fig.savefig(os.path.join(output_dir, 'toy_training_curves.png'), dpi=150, bbox_inches='tight') +plt.close(fig) +print("Saved toy_training_curves.png") + +# Compare v1 (no term grad) vs v2 (with term grad) for credit bridge +if v1_data: + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + + # v1 credit bridge (no term grad matching) + v1_log = v1_data['log'] + ax.plot(v1_log['steps'], v1_log['credit_costate_cos'], + '--', color='red', label='Credit Bridge (w/o terminal grad)', alpha=0.8) + + # v2 credit bridge (with term grad) + v2_log = all_data[0]['log'] # seed 42 + ax.plot(v2_log['steps'], v2_log['credit_costate_cos'], + '-', color='green', label='Credit Bridge (w/ terminal grad)') + + # State bridge for reference + ax.plot(v2_log['steps'], 
v2_log['state_costate_cos'], + '-', color='orange', label='State Bridge') + + ax.set_xlabel('Training Step', fontsize=12) + ax.set_ylabel('Avg Costate Cosine', fontsize=12) + ax.set_title('Effect of Terminal Gradient Matching', fontsize=13) + ax.legend(fontsize=11) + ax.grid(True, alpha=0.3) + ax.set_ylim(-0.1, 1.05) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'toy_term_grad_effect.png'), dpi=150) + plt.close(fig) + print("Saved toy_term_grad_effect.png") + +# Bridge residual (from v1 which has it) +if v1_data and v1_data['log'].get('bridge_residual'): + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + ax.plot(v1_data['log']['steps'], v1_data['log']['bridge_residual'], '-', color='green') + ax.set_xlabel('Training Step', fontsize=12) + ax.set_ylabel('Bridge Residual', fontsize=12) + ax.set_title('Credit Bridge: Bridge Residual Over Training', fontsize=13) + ax.grid(True, alpha=0.3) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'toy_bridge_residual.png'), dpi=150) + plt.close(fig) + print("Saved toy_bridge_residual.png") + +# Print summary table +print("\n" + "="*80) +print("TOY LQ FINAL RESULTS (3 seeds, 8000 steps)") +print("="*80) + +for method in methods: + cos_vals = [] + rho_vals = [] + nudge_vals = [] + for data in all_data: + pl = data['final_per_layer'] + cos_vals.append(np.mean(pl[f'{method}_costate_cos'])) + rho_vals.append(np.mean(pl[f'{method}_rho'])) + nudge_vals.append(np.mean(pl[f'{method}_nudge'])) + + cos_mean, cos_std = np.mean(cos_vals), np.std(cos_vals) + rho_mean, rho_std = np.mean(rho_vals), np.std(rho_vals) + nudge_mean, nudge_std = np.mean(nudge_vals), np.std(nudge_vals) + + print(f"{labels[method]:<20} Cosine: {cos_mean:.4f}±{cos_std:.4f} " + f"ρ: {rho_mean:.4f}±{rho_std:.4f} " + f"Nudge: {nudge_mean:.4f}±{nudge_std:.4f}") diff --git a/experiments/toy_lq.py b/experiments/toy_lq.py new file mode 100644 index 0000000..4fd8919 --- /dev/null +++ b/experiments/toy_lq.py @@ -0,0 +1,395 @@ +""" +Phase A: 
Linear-Quadratic Residual Sanity Check. + +Fixed forward dynamics (no forward net training). +Only train feedback/bridge models. +Compare DFA, State Bridge, Credit Bridge against exact costate. + +System: + h_{l+1} = M_l h_l + sigma * xi_l, xi_l ~ N(0, I) + Phi(h_L, y) = 0.5 * ||C h_L - y||^2 + Exact costate: a_L = C^T (C h_L - y), a_l = M_l^T a_{l+1} +""" +import os +import sys +import json +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from datetime import datetime + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.value_net import ValueNet, create_ema_model, update_ema +from models.state_bridge import StateBridgeNet +from metrics.credit_metrics import ( + cosine_similarity_batch, perturbation_correlation, nudging_test, bridge_residual +) + + +def generate_stable_dynamics(d, L, spectral_max=0.05, seed=42): + """Generate stable linear maps M_l = I + A_l with ||A_l||_2 <= spectral_max.""" + rng = np.random.RandomState(seed) + Ms = [] + for _ in range(L): + A = rng.randn(d, d).astype(np.float32) + # Scale to desired spectral norm + u, s, v = np.linalg.svd(A, full_matrices=False) + A = A * (spectral_max / s[0]) + M = np.eye(d, dtype=np.float32) + A + Ms.append(torch.from_numpy(M)) + return Ms # list of (d, d) + + +def rollout_forward(h0, Ms, sigma, L, device): + """Roll out forward dynamics: h_{l+1} = M_l h_l + sigma * xi_l.""" + batch = h0.shape[0] + d = h0.shape[1] + hiddens = [h0] + h = h0 + for l in range(L): + M = Ms[l].to(device) + noise = sigma * torch.randn(batch, d, device=device) + h = h @ M.T + noise + hiddens.append(h) + return hiddens # [h_0, ..., h_L] + + +def terminal_loss(hL, C, y): + """Phi(hL, y) = 0.5 * ||C hL - y||^2, returns per-sample loss.""" + diff = hL @ C.T - y # (batch, m) + return 0.5 * (diff ** 2).sum(dim=-1) # (batch,) + + +def exact_costate(hiddens, Ms, C, y, device): + """Compute exact costate a_l for all layers.""" + L = 
len(hiddens) - 1 + hL = hiddens[L] + # Terminal: a_L = C^T (C h_L - y) + diff = hL @ C.T - y # (batch, m) + a_L = diff @ C # (batch, d) + + costates = [None] * (L + 1) + costates[L] = a_L + for l in range(L - 1, -1, -1): + M = Ms[l].to(device) + costates[l] = costates[l + 1] @ M # a_l = M_l^T a_{l+1} -> a_{l+1} @ M + return costates + + +def make_forward_fn_from_layer(hiddens, Ms, C, y, sigma, start_layer, device): + """Create a function that rolls forward from layer start_layer and returns per-sample loss.""" + L = len(Ms) + + def forward_fn(h): + current = h + for l in range(start_layer, L): + M = Ms[l].to(device) + # No noise for perturbation test (deterministic rollout) + current = current @ M.T + return terminal_loss(current, C, y) + + return forward_fn + + +def run_experiment(args): + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + # Hyperparams + d = args.d_hidden # 64 + m = args.output_dim # 10 + L = args.num_layers # 12 + sigma = args.sigma # 0.03 + batch_size = args.batch_size # 256 + num_steps = args.num_steps # 5000 + lr_fb = args.lr_fb # 1e-3 + lam = args.lam # 0.1 + K = args.K # 8 + ema_momentum = args.ema_momentum # 0.995 + sigma_bridge = args.sigma_bridge # 0.03 + + print(f"=== Toy LQ Experiment ===") + print(f"d={d}, m={m}, L={L}, sigma={sigma}, seed={args.seed}") + print(f"device={device}") + + # Generate fixed dynamics + Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=args.seed) + C = torch.randn(m, d, device=device) / np.sqrt(d) + + # DFA random feedback matrices + Bs_dfa = [] + for l in range(L + 1): + B = torch.randn(d, m, device=device) / np.sqrt(m) + Bs_dfa.append(B) + + # State Bridge model + state_bridge = StateBridgeNet(d_hidden=d, s_dim=m, time_embed_dim=16, + hidden_dim=128, num_layers=2).to(device) + opt_state = optim.Adam(state_bridge.parameters(), lr=lr_fb) + + # Credit Bridge value net + value_net = ValueNet(d_hidden=d, 
s_dim=m, time_embed_dim=16, + hidden_dim=128, num_layers=2).to(device) + value_net_ema = create_ema_model(value_net) + opt_value = optim.Adam(value_net.parameters(), lr=lr_fb) + + # Training logs + log = { + 'steps': [], + 'state_bridge_loss': [], + 'credit_bridge_loss': [], + 'dfa_costate_cos': [], + 'state_costate_cos': [], + 'credit_costate_cos': [], + 'dfa_rho': [], + 'state_rho': [], + 'credit_rho': [], + 'dfa_nudge': [], + 'state_nudge': [], + 'credit_nudge': [], + 'bridge_residual': [], + } + + for step in range(1, num_steps + 1): + # Generate data + h0 = torch.randn(batch_size, d, device=device) + y = torch.randn(batch_size, m, device=device) + + # Forward rollout + hiddens = rollout_forward(h0, Ms, sigma, L, device) + hL = hiddens[L] + + # Terminal error + e_T = (hL @ C.T - y) # (batch, m) - gradient of Phi w.r.t. prediction + + # Terminal modulation code s = e_T (P=I) + s = e_T.detach() + + # ---- Train State Bridge ---- + state_loss = 0.0 + hL_detached = hL.detach() + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch_size,), l / L, device=device) + pred_hL = state_bridge(h_l_det, t_l, s) + state_loss = state_loss + ((pred_hL - hL_detached) ** 2).sum(dim=-1).mean() + state_loss = state_loss / L + + opt_state.zero_grad() + state_loss.backward() + opt_state.step() + + # ---- Train Credit Bridge (value net) ---- + # Terminal boundary: V(h_L, 1, s) should equal Phi(h_L, y) + hL_det = hL.detach().requires_grad_(False) + t_L = torch.ones(batch_size, device=device) + true_loss = terminal_loss(hL_det, C, y).detach() + + V_terminal = value_net(hL_det, t_L, s) + loss_term = ((V_terminal - true_loss) ** 2).mean() + + # Bridge consistency + loss_bridge = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch_size,), l / L, device=device) + t_l_next = torch.full((batch_size,), (l + 1) / L, device=device) + + V_l = value_net(h_l_det, t_l, s) + + # Generate noisy next states + with torch.no_grad(): + M = 
Ms[l].to(device) + h_next_det = hiddens[l + 1].detach() + + log_terms = [] + for k in range(K): + noise = sigma_bridge * torch.randn(batch_size, d, device=device) + h_next_noisy = h_next_det + noise + V_next = value_net_ema(h_next_noisy, t_l_next, s) + log_terms.append(-V_next / lam) + + log_terms_stack = torch.stack(log_terms, dim=-1) # (batch, K) + V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K)) + + loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() + + loss_bridge = loss_bridge / L + loss_value = loss_term + loss_bridge + + opt_value.zero_grad() + loss_value.backward() + opt_value.step() + update_ema(value_net, value_net_ema, ema_momentum) + + # ---- Evaluation ---- + if step % args.eval_every == 0 or step == 1: + with torch.no_grad(): + eval_batch = min(batch_size, 128) + h0_eval = torch.randn(eval_batch, d, device=device) + y_eval = torch.randn(eval_batch, m, device=device) + hiddens_eval = rollout_forward(h0_eval, Ms, sigma, L, device) + hL_eval = hiddens_eval[L] + e_T_eval = hL_eval @ C.T - y_eval + s_eval = e_T_eval.detach() + + # Exact costate + costates_exact = exact_costate(hiddens_eval, Ms, C, y_eval, device) + + # Compute credits for each method at each layer + dfa_cos_layers = [] + state_cos_layers = [] + credit_cos_layers = [] + dfa_rho_layers = [] + state_rho_layers = [] + credit_rho_layers = [] + dfa_nudge_layers = [] + state_nudge_layers = [] + credit_nudge_layers = [] + bridge_res_layers = [] + + for l in range(L + 1): + h_l = hiddens_eval[l].detach() + a_exact = costates_exact[l].detach() + t_l = torch.full((eval_batch,), l / L, device=device) + + # DFA credit + a_dfa = e_T_eval @ Bs_dfa[l].T # (batch, d) + + # State bridge credit + h_l_req = h_l.clone().requires_grad_(True) + pred_hL = state_bridge(h_l_req, t_l, s_eval) + # Loss through state bridge prediction + pred_out = pred_hL @ C.T # Use C as output projection for consistency + pred_loss = 0.5 * ((pred_out - y_eval) ** 2).sum(dim=-1) + a_state = 
torch.autograd.grad(pred_loss.sum(), h_l_req, create_graph=False)[0] + + # Credit bridge credit + h_l_req2 = h_l.clone().requires_grad_(True) + V_l = value_net(h_l_req2, t_l, s_eval) + a_credit = torch.autograd.grad(V_l.sum(), h_l_req2, create_graph=False)[0] + + # Costate cosine + dfa_cos_layers.append(cosine_similarity_batch(a_dfa, a_exact)) + state_cos_layers.append(cosine_similarity_batch(a_state, a_exact)) + credit_cos_layers.append(cosine_similarity_batch(a_credit, a_exact)) + + # Perturbation correlation and nudging (skip terminal layer for forward_fn) + if l < L: + fwd_fn = make_forward_fn_from_layer(hiddens_eval, Ms, C, y_eval, sigma, l, device) + + dfa_rho = perturbation_correlation(h_l, a_dfa, fwd_fn, epsilon=1e-3, M=16) + state_rho = perturbation_correlation(h_l, a_state.detach(), fwd_fn, epsilon=1e-3, M=16) + credit_rho = perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16) + dfa_rho_layers.append(dfa_rho) + state_rho_layers.append(state_rho) + credit_rho_layers.append(credit_rho) + + dfa_nud = nudging_test(h_l, a_dfa, fwd_fn, eta=0.01) + state_nud = nudging_test(h_l, a_state.detach(), fwd_fn, eta=0.01) + credit_nud = nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01) + dfa_nudge_layers.append(dfa_nud) + state_nudge_layers.append(state_nud) + credit_nudge_layers.append(credit_nud) + + # Bridge residual for credit bridge + if l < L: + t_l_next = torch.full((eval_batch,), (l + 1) / L, device=device) + h_next = hiddens_eval[l + 1].detach() + noisy_list = [h_next + sigma_bridge * torch.randn_like(h_next) for _ in range(K)] + br = bridge_residual(value_net, value_net_ema, h_l, t_l, s_eval, + noisy_list, t_l_next, lam) + bridge_res_layers.append(br) + + # Average across layers + avg_dfa_cos = np.mean(dfa_cos_layers) + avg_state_cos = np.mean(state_cos_layers) + avg_credit_cos = np.mean(credit_cos_layers) + avg_dfa_rho = np.mean(dfa_rho_layers) + avg_state_rho = np.mean(state_rho_layers) + avg_credit_rho = 
np.mean(credit_rho_layers) + avg_dfa_nudge = np.mean(dfa_nudge_layers) + avg_state_nudge = np.mean(state_nudge_layers) + avg_credit_nudge = np.mean(credit_nudge_layers) + avg_bridge_res = np.mean(bridge_res_layers) if bridge_res_layers else 0.0 + + log['steps'].append(step) + log['dfa_costate_cos'].append(avg_dfa_cos) + log['state_costate_cos'].append(avg_state_cos) + log['credit_costate_cos'].append(avg_credit_cos) + log['dfa_rho'].append(avg_dfa_rho) + log['state_rho'].append(avg_state_rho) + log['credit_rho'].append(avg_credit_rho) + log['dfa_nudge'].append(avg_dfa_nudge) + log['state_nudge'].append(avg_state_nudge) + log['credit_nudge'].append(avg_credit_nudge) + log['bridge_residual'].append(avg_bridge_res) + log['state_bridge_loss'].append(state_loss.item()) + log['credit_bridge_loss'].append(loss_value.item()) + + print(f"Step {step}/{num_steps}") + print(f" Costate cos - DFA: {avg_dfa_cos:.4f}, State: {avg_state_cos:.4f}, Credit: {avg_credit_cos:.4f}") + print(f" Perturb rho - DFA: {avg_dfa_rho:.4f}, State: {avg_state_rho:.4f}, Credit: {avg_credit_rho:.4f}") + print(f" Nudging - DFA: {avg_dfa_nudge:.4f}, State: {avg_state_nudge:.4f}, Credit: {avg_credit_nudge:.4f}") + print(f" Bridge res - {avg_bridge_res:.4f}") + print(f" Losses - State: {state_loss.item():.4f}, Credit: {loss_value.item():.4f}") + print(f" Per-layer costate cos (credit): {['%.3f' % x for x in credit_cos_layers]}") + + # Save results + os.makedirs(args.output_dir, exist_ok=True) + results = { + 'config': vars(args), + 'log': log, + 'final_per_layer': { + 'dfa_costate_cos': dfa_cos_layers, + 'state_costate_cos': state_cos_layers, + 'credit_costate_cos': credit_cos_layers, + 'dfa_rho': dfa_rho_layers, + 'state_rho': state_rho_layers, + 'credit_rho': credit_rho_layers, + 'dfa_nudge': dfa_nudge_layers, + 'state_nudge': state_nudge_layers, + 'credit_nudge': credit_nudge_layers, + 'bridge_residual': bridge_res_layers, + } + } + + out_path = os.path.join(args.output_dir, 
f'toy_lq_seed{args.seed}.json') + with open(out_path, 'w') as f: + json.dump(results, f, indent=2) + print(f"\nResults saved to {out_path}") + + # Also save models + torch.save(value_net.state_dict(), os.path.join(args.output_dir, f'value_net_seed{args.seed}.pt')) + torch.save(state_bridge.state_dict(), os.path.join(args.output_dir, f'state_bridge_seed{args.seed}.pt')) + + return results + + +def main(): + parser = argparse.ArgumentParser(description='Toy LQ Sanity Check') + parser.add_argument('--d_hidden', type=int, default=64) + parser.add_argument('--output_dim', type=int, default=10) + parser.add_argument('--num_layers', type=int, default=12) + parser.add_argument('--sigma', type=float, default=0.03) + parser.add_argument('--batch_size', type=int, default=256) + parser.add_argument('--num_steps', type=int, default=5000) + parser.add_argument('--lr_fb', type=float, default=1e-3) + parser.add_argument('--lam', type=float, default=0.1) + parser.add_argument('--K', type=int, default=8) + parser.add_argument('--ema_momentum', type=float, default=0.995) + parser.add_argument('--sigma_bridge', type=float, default=0.03) + parser.add_argument('--eval_every', type=int, default=200) + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--gpu', type=int, default=1) + parser.add_argument('--output_dir', type=str, default='results/toy_lq') + args = parser.parse_args() + run_experiment(args) + + +if __name__ == '__main__': + main() diff --git a/experiments/toy_lq_sweep.py b/experiments/toy_lq_sweep.py new file mode 100644 index 0000000..ae82ef0 --- /dev/null +++ b/experiments/toy_lq_sweep.py @@ -0,0 +1,243 @@ +""" +Sweep over credit bridge hyperparameters to find a configuration +where the value field gradient actually aligns with the costate. + +Key hypothesis: the credit bridge needs sufficient noise (sigma_bridge) +and temperature (lambda) to make V_phi sensitive to cost-relevant directions. 
+""" +import os +import sys +import json +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from itertools import product + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.value_net import ValueNet, create_ema_model, update_ema +from models.state_bridge import StateBridgeNet +from experiments.toy_lq import ( + generate_stable_dynamics, rollout_forward, terminal_loss, + exact_costate, make_forward_fn_from_layer +) +from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation, nudging_test + + +def run_credit_bridge_config(config, device): + """Run credit bridge with specific hyperparameters and return final metrics.""" + d = 64 + m = 10 + L = 12 + sigma = 0.03 + batch_size = 256 + num_steps = config['num_steps'] + lr = config['lr'] + lam = config['lam'] + K = config['K'] + ema_momentum = config['ema_momentum'] + sigma_bridge = config['sigma_bridge'] + hidden_dim = config.get('hidden_dim', 128) + use_ln = config.get('use_ln', True) + + torch.manual_seed(42) + np.random.seed(42) + + Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=42) + C = torch.randn(m, d, device=device) / np.sqrt(d) + + # Value net - optionally without LayerNorm + value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16, + hidden_dim=hidden_dim, num_layers=2).to(device) + if not use_ln: + value_net.ln = nn.Identity() + + value_net_ema = create_ema_model(value_net) + opt_value = optim.Adam(value_net.parameters(), lr=lr) + + best_cos = -1.0 + best_step = 0 + history = [] + + for step in range(1, num_steps + 1): + h0 = torch.randn(batch_size, d, device=device) + y = torch.randn(batch_size, m, device=device) + hiddens = rollout_forward(h0, Ms, sigma, L, device) + hL = hiddens[L] + e_T = hL @ C.T - y + s = e_T.detach() + true_loss = terminal_loss(hL.detach(), C, y).detach() + + # Terminal boundary + hL_det = hL.detach() + t_L = torch.ones(batch_size, device=device) + V_terminal = 
value_net(hL_det, t_L, s) + loss_term = ((V_terminal - true_loss) ** 2).mean() + + # Bridge consistency + loss_bridge = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch_size,), l / L, device=device) + t_l_next = torch.full((batch_size,), (l + 1) / L, device=device) + + V_l = value_net(h_l_det, t_l, s) + + with torch.no_grad(): + h_next_det = hiddens[l + 1].detach() + log_terms = [] + for k in range(K): + noise = sigma_bridge * torch.randn(batch_size, d, device=device) + h_noisy = h_next_det + noise + V_next = value_net_ema(h_noisy, t_l_next, s) + log_terms.append(-V_next / lam) + + log_terms_stack = torch.stack(log_terms, dim=-1) + V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K)) + + loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() + + loss_bridge = loss_bridge / L + total_loss = loss_term + loss_bridge + + opt_value.zero_grad() + total_loss.backward() + torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) + opt_value.step() + update_ema(value_net, value_net_ema, ema_momentum) + + # Quick evaluation + if step % 500 == 0 or step == num_steps: + with torch.no_grad(): + eval_batch = 128 + h0_e = torch.randn(eval_batch, d, device=device) + y_e = torch.randn(eval_batch, m, device=device) + hiddens_e = rollout_forward(h0_e, Ms, sigma, L, device) + hL_e = hiddens_e[L] + e_T_e = hL_e @ C.T - y_e + s_e = e_T_e.detach() + costates = exact_costate(hiddens_e, Ms, C, y_e, device) + + cos_list = [] + rho_list = [] + nudge_list = [] + for l in range(L): + h_l = hiddens_e[l].detach() + t_l = torch.full((eval_batch,), l / L, device=device) + a_exact = costates[l].detach() + + h_l_req = h_l.clone().requires_grad_(True) + V_l = value_net(h_l_req, t_l, s_e) + a_credit = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0] + + cos_list.append(cosine_similarity_batch(a_credit, a_exact)) + + fwd_fn = make_forward_fn_from_layer(hiddens_e, Ms, C, y_e, sigma, l, device) + rho = 
perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16) + rho_list.append(rho) + nud = nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01) + nudge_list.append(nud) + + avg_cos = np.mean(cos_list) + avg_rho = np.mean(rho_list) + avg_nudge = np.mean(nudge_list) + + if avg_cos > best_cos: + best_cos = avg_cos + best_step = step + + history.append({ + 'step': step, + 'avg_cos': avg_cos, + 'avg_rho': avg_rho, + 'avg_nudge': avg_nudge, + 'loss_term': loss_term.item(), + 'loss_bridge': loss_bridge.item(), + }) + + return { + 'best_cos': best_cos, + 'best_step': best_step, + 'final_cos': history[-1]['avg_cos'], + 'final_rho': history[-1]['avg_rho'], + 'final_nudge': history[-1]['avg_nudge'], + 'history': history, + } + + +def main(): + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + print(f"Device: {device}") + + # Sweep configurations + configs = [ + # Baseline (original) + {'name': 'base', 'lam': 0.1, 'sigma_bridge': 0.03, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # Larger noise + {'name': 'noise_0.1', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # Much larger noise + {'name': 'noise_0.3', 'lam': 0.1, 'sigma_bridge': 0.3, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # Larger lambda + {'name': 'lam_1.0', 'lam': 1.0, 'sigma_bridge': 0.03, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # Large noise + large lambda + {'name': 'noise_lam', 'lam': 1.0, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # No LayerNorm + {'name': 'no_ln', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': False}, + # Larger value net + {'name': 
'big_vnet', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 256, 'use_ln': True}, + # Slower EMA + {'name': 'ema_0.999', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.999, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # More K samples + {'name': 'K16', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 16, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # Larger noise + large lambda + no LN + {'name': 'best_combo', 'lam': 1.0, 'sigma_bridge': 0.3, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': False}, + # Very large sigma + {'name': 'noise_1.0', 'lam': 1.0, 'sigma_bridge': 1.0, 'K': 8, 'lr': 1e-3, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + # Lower lr + {'name': 'lr_3e-4', 'lam': 0.1, 'sigma_bridge': 0.1, 'K': 8, 'lr': 3e-4, + 'ema_momentum': 0.995, 'num_steps': 5000, 'hidden_dim': 128, 'use_ln': True}, + ] + + results = {} + for cfg in configs: + name = cfg.pop('name') + print(f"\n{'='*50}") + print(f"Config: {name}") + print(f" {cfg}") + res = run_credit_bridge_config(cfg, device) + results[name] = res + print(f" Best cos: {res['best_cos']:.4f} (step {res['best_step']})") + print(f" Final cos: {res['final_cos']:.4f}, rho: {res['final_rho']:.4f}, nudge: {res['final_nudge']:.4f}") + cfg['name'] = name # restore + + # Print summary + print("\n" + "="*80) + print("SWEEP SUMMARY") + print("="*80) + print(f"{'Config':<20} {'Best Cos':<12} {'Final Cos':<12} {'Final Rho':<12} {'Final Nudge':<12}") + print("-"*68) + for name, res in results.items(): + print(f"{name:<20} {res['best_cos']:<12.4f} {res['final_cos']:<12.4f} " + f"{res['final_rho']:<12.4f} {res['final_nudge']:<12.4f}") + + # Save + os.makedirs('results/toy_lq', exist_ok=True) + with open('results/toy_lq/sweep_results.json', 'w') as f: + json.dump(results, f, indent=2) + print("\nSaved to 
results/toy_lq/sweep_results.json") + + +if __name__ == '__main__': + main() diff --git a/experiments/toy_lq_v2.py b/experiments/toy_lq_v2.py new file mode 100644 index 0000000..ab766b6 --- /dev/null +++ b/experiments/toy_lq_v2.py @@ -0,0 +1,327 @@ +""" +Phase A v2: Enhanced toy LQ experiment. + +Key improvements over v1: +1. Terminal gradient matching: V_phi at terminal layer should have grad_h V matching + the exact terminal gradient (this is LOCAL info, no hidden BP needed). +2. Larger noise sweep integrated. +3. Optional FM auxiliary for gradient smoothness. +4. Better diagnostics. + +The terminal gradient a_L = C^T(C h_L - y) is computed from output layer only, +so using it is allowed under the "no hidden BP anchor" constraint. +""" +import os +import sys +import json +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.value_net import ValueNet, create_ema_model, update_ema +from models.state_bridge import StateBridgeNet +from experiments.toy_lq import ( + generate_stable_dynamics, rollout_forward, terminal_loss, + exact_costate, make_forward_fn_from_layer +) +from metrics.credit_metrics import ( + cosine_similarity_batch, perturbation_correlation, nudging_test +) + + +def run_experiment(args): + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + d = args.d_hidden + m = args.output_dim + L = args.num_layers + sigma = args.sigma + batch_size = args.batch_size + num_steps = args.num_steps + lr = args.lr_fb + lam = args.lam + K = args.K + ema_momentum = args.ema_momentum + sigma_bridge = args.sigma_bridge + + print(f"=== Toy LQ v2 Experiment ===") + print(f"d={d}, m={m}, L={L}, sigma={sigma}, seed={args.seed}") + print(f"lam={lam}, sigma_bridge={sigma_bridge}, K={K}") + print(f"terminal_grad_weight={args.term_grad_weight}") 
+ print(f"fm_weight={args.fm_weight}") + print(f"device={device}") + + Ms = generate_stable_dynamics(d, L, spectral_max=0.05, seed=args.seed) + C = torch.randn(m, d, device=device) / np.sqrt(d) + + # DFA + Bs_dfa = [torch.randn(d, m, device=device) / np.sqrt(m) for _ in range(L + 1)] + + # State Bridge + state_bridge = StateBridgeNet(d_hidden=d, s_dim=m, time_embed_dim=16, + hidden_dim=128, num_layers=2).to(device) + opt_state = optim.Adam(state_bridge.parameters(), lr=lr) + + # Credit Bridge + value_net = ValueNet(d_hidden=d, s_dim=m, time_embed_dim=16, + hidden_dim=args.vnet_hidden, num_layers=args.vnet_layers).to(device) + value_net_ema = create_ema_model(value_net) + opt_value = optim.Adam(value_net.parameters(), lr=lr) + + log = {key: [] for key in [ + 'steps', + 'dfa_costate_cos', 'state_costate_cos', 'credit_costate_cos', + 'dfa_rho', 'state_rho', 'credit_rho', + 'dfa_nudge', 'state_nudge', 'credit_nudge', + 'bridge_residual', 'state_bridge_loss', 'credit_bridge_loss', + 'term_loss', 'bridge_loss', 'term_grad_loss', 'fm_loss', + ]} + + for step in range(1, num_steps + 1): + h0 = torch.randn(batch_size, d, device=device) + y = torch.randn(batch_size, m, device=device) + hiddens = rollout_forward(h0, Ms, sigma, L, device) + hL = hiddens[L] + e_T = hL @ C.T - y + s = e_T.detach() + + # ---- Train State Bridge ---- + state_loss = 0.0 + hL_det = hL.detach() + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch_size,), l / L, device=device) + pred_hL = state_bridge(h_l_det, t_l, s) + state_loss = state_loss + ((pred_hL - hL_det) ** 2).sum(dim=-1).mean() + state_loss = state_loss / L + + opt_state.zero_grad() + state_loss.backward() + opt_state.step() + + # ---- Train Credit Bridge ---- + # 1. 
Terminal boundary: V(h_L, 1, s) ≈ Phi(h_L, y) + hL_det = hL.detach() + t_L = torch.ones(batch_size, device=device) + true_loss = terminal_loss(hL_det, C, y).detach() + V_terminal = value_net(hL_det, t_L, s) + loss_term = ((V_terminal - true_loss) ** 2).mean() + + # 2. Terminal gradient matching: grad_h V(h_L, 1, s) ≈ a_L^exact + # This uses only terminal-local information (no hidden BP) + loss_term_grad = torch.tensor(0.0, device=device) + if args.term_grad_weight > 0: + hL_req = hL.detach().requires_grad_(True) + V_at_L = value_net(hL_req, t_L, s) + grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0] + # Exact terminal gradient: C^T (C h_L - y) + a_L_exact = (e_T @ C).detach() # (batch, d) -- stop grad on target + loss_term_grad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean() + + # 3. Bridge consistency + loss_bridge = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch_size,), l / L, device=device) + t_l_next = torch.full((batch_size,), (l + 1) / L, device=device) + + V_l = value_net(h_l_det, t_l, s) + + with torch.no_grad(): + h_next_det = hiddens[l + 1].detach() + log_terms = [] + for k in range(K): + noise = sigma_bridge * torch.randn(batch_size, d, device=device) + h_noisy = h_next_det + noise + V_next = value_net_ema(h_noisy, t_l_next, s) + log_terms.append(-V_next / lam) + log_terms_stack = torch.stack(log_terms, dim=-1) + V_target = -lam * (torch.logsumexp(log_terms_stack, dim=-1) - np.log(K)) + + loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() + loss_bridge = loss_bridge / L + + # 4. 
FM auxiliary (optional): enforce gradient smoothness + loss_fm = torch.tensor(0.0, device=device) + if args.fm_weight > 0: + for l in range(L): + tau = torch.rand(batch_size, 1, device=device) + h_l_det = hiddens[l].detach() + h_next_det = hiddens[l + 1].detach() + f_l = h_next_det - h_l_det # residual + + eps = torch.randn(batch_size, d, device=device) + h_mid = h_l_det + tau * f_l + (tau * (1 - tau)).sqrt() * sigma_bridge * eps + h_mid.requires_grad_(True) + + t_mid = torch.full((batch_size, 1), 0, device=device) + t_mid = (l + tau) / L + t_mid_flat = t_mid.squeeze(-1) + + V_mid = value_net(h_mid, t_mid_flat, s) + grad_V_mid = torch.autograd.grad(V_mid.sum(), h_mid, create_graph=True)[0] + + # Interpolated target gradient + # Get a_l and a_{l+1} from current value net (no create_graph for targets) + h_l_r = h_l_det.clone().requires_grad_(True) + t_l_v = torch.full((batch_size,), l / L, device=device) + V_l_ = value_net(h_l_r, t_l_v, s) + a_l = torch.autograd.grad(V_l_.sum(), h_l_r, create_graph=False)[0].detach() + + h_next_r = h_next_det.clone().requires_grad_(True) + t_next_v = torch.full((batch_size,), (l + 1) / L, device=device) + V_next_ = value_net(h_next_r, t_next_v, s) + a_next = torch.autograd.grad(V_next_.sum(), h_next_r, create_graph=False)[0].detach() + + target_grad = ((1 - tau) * a_l + tau * a_next).detach() + loss_fm = loss_fm + ((grad_V_mid - target_grad) ** 2).sum(dim=-1).mean() + loss_fm = loss_fm / L + + total_loss = (loss_term + + loss_bridge + + args.term_grad_weight * loss_term_grad + + args.fm_weight * loss_fm) + + opt_value.zero_grad() + total_loss.backward() + torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) + opt_value.step() + update_ema(value_net, value_net_ema, ema_momentum) + + # ---- Evaluation ---- + if step % args.eval_every == 0 or step == 1: + with torch.no_grad(): + eval_batch = 128 + h0_e = torch.randn(eval_batch, d, device=device) + y_e = torch.randn(eval_batch, m, device=device) + hiddens_e = rollout_forward(h0_e, 
Ms, sigma, L, device) + hL_e = hiddens_e[L] + e_T_e = hL_e @ C.T - y_e + s_e = e_T_e.detach() + costates = exact_costate(hiddens_e, Ms, C, y_e, device) + + dfa_cos, state_cos, credit_cos = [], [], [] + dfa_rho, state_rho, credit_rho = [], [], [] + dfa_nudge, state_nudge, credit_nudge = [], [], [] + bridge_res_list = [] + + for l in range(L): + h_l = hiddens_e[l].detach() + a_exact = costates[l].detach() + t_l = torch.full((eval_batch,), l / L, device=device) + + # DFA + a_dfa = e_T_e @ Bs_dfa[l].T + # State bridge + h_l_r1 = h_l.clone().requires_grad_(True) + pred_hL = state_bridge(h_l_r1, t_l, s_e) + pred_out = pred_hL @ C.T + pred_loss = 0.5 * ((pred_out - y_e) ** 2).sum(dim=-1) + a_state = torch.autograd.grad(pred_loss.sum(), h_l_r1, create_graph=False)[0] + # Credit bridge + h_l_r2 = h_l.clone().requires_grad_(True) + V_l = value_net(h_l_r2, t_l, s_e) + a_credit = torch.autograd.grad(V_l.sum(), h_l_r2, create_graph=False)[0] + + dfa_cos.append(cosine_similarity_batch(a_dfa, a_exact)) + state_cos.append(cosine_similarity_batch(a_state, a_exact)) + credit_cos.append(cosine_similarity_batch(a_credit, a_exact)) + + fwd_fn = make_forward_fn_from_layer(hiddens_e, Ms, C, y_e, sigma, l, device) + + dfa_rho.append(perturbation_correlation(h_l, a_dfa, fwd_fn, epsilon=1e-3, M=16)) + state_rho.append(perturbation_correlation(h_l, a_state.detach(), fwd_fn, epsilon=1e-3, M=16)) + credit_rho.append(perturbation_correlation(h_l, a_credit.detach(), fwd_fn, epsilon=1e-3, M=16)) + + dfa_nudge.append(nudging_test(h_l, a_dfa, fwd_fn, eta=0.01)) + state_nudge.append(nudging_test(h_l, a_state.detach(), fwd_fn, eta=0.01)) + credit_nudge.append(nudging_test(h_l, a_credit.detach(), fwd_fn, eta=0.01)) + + avg = lambda x: float(np.mean(x)) + log['steps'].append(step) + log['dfa_costate_cos'].append(avg(dfa_cos)) + log['state_costate_cos'].append(avg(state_cos)) + log['credit_costate_cos'].append(avg(credit_cos)) + log['dfa_rho'].append(avg(dfa_rho)) + 
log['state_rho'].append(avg(state_rho)) + log['credit_rho'].append(avg(credit_rho)) + log['dfa_nudge'].append(avg(dfa_nudge)) + log['state_nudge'].append(avg(state_nudge)) + log['credit_nudge'].append(avg(credit_nudge)) + log['state_bridge_loss'].append(state_loss.item()) + log['credit_bridge_loss'].append(total_loss.item()) + log['term_loss'].append(loss_term.item()) + log['bridge_loss'].append(loss_bridge.item()) + log['term_grad_loss'].append(loss_term_grad.item() if isinstance(loss_term_grad, torch.Tensor) else loss_term_grad) + log['fm_loss'].append(loss_fm.item() if isinstance(loss_fm, torch.Tensor) else loss_fm) + + print(f"Step {step}/{num_steps}") + print(f" Costate cos - DFA: {avg(dfa_cos):.4f}, State: {avg(state_cos):.4f}, Credit: {avg(credit_cos):.4f}") + print(f" Perturb rho - DFA: {avg(dfa_rho):.4f}, State: {avg(state_rho):.4f}, Credit: {avg(credit_rho):.4f}") + print(f" Nudging - DFA: {avg(dfa_nudge):.4f}, State: {avg(state_nudge):.4f}, Credit: {avg(credit_nudge):.4f}") + print(f" Losses - term: {loss_term.item():.4f}, bridge: {loss_bridge.item():.4f}, " + f"tgrad: {loss_term_grad.item() if isinstance(loss_term_grad, torch.Tensor) else 0:.4f}, " + f"fm: {loss_fm.item() if isinstance(loss_fm, torch.Tensor) else 0:.4f}") + print(f" Per-layer credit cos: {['%.3f' % x for x in credit_cos]}") + + # Save + os.makedirs(args.output_dir, exist_ok=True) + results = { + 'config': vars(args), + 'log': log, + 'final_per_layer': { + 'dfa_costate_cos': dfa_cos, + 'state_costate_cos': state_cos, + 'credit_costate_cos': credit_cos, + 'dfa_rho': dfa_rho, + 'state_rho': state_rho, + 'credit_rho': credit_rho, + 'dfa_nudge': dfa_nudge, + 'state_nudge': state_nudge, + 'credit_nudge': credit_nudge, + } + } + tag = f"seed{args.seed}_lam{args.lam}_sig{args.sigma_bridge}_tgw{args.term_grad_weight}_fm{args.fm_weight}" + out_path = os.path.join(args.output_dir, f'toy_lq_v2_{tag}.json') + with open(out_path, 'w') as f: + json.dump(results, f, indent=2) + print(f"\nResults saved 
to {out_path}") + return results + + +def main(): + parser = argparse.ArgumentParser(description='Toy LQ v2') + parser.add_argument('--d_hidden', type=int, default=64) + parser.add_argument('--output_dim', type=int, default=10) + parser.add_argument('--num_layers', type=int, default=12) + parser.add_argument('--sigma', type=float, default=0.03) + parser.add_argument('--batch_size', type=int, default=256) + parser.add_argument('--num_steps', type=int, default=8000) + parser.add_argument('--lr_fb', type=float, default=1e-3) + parser.add_argument('--lam', type=float, default=0.1) + parser.add_argument('--K', type=int, default=8) + parser.add_argument('--ema_momentum', type=float, default=0.995) + parser.add_argument('--sigma_bridge', type=float, default=0.1) + parser.add_argument('--eval_every', type=int, default=500) + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--gpu', type=int, default=1) + parser.add_argument('--output_dir', type=str, default='results/toy_lq') + parser.add_argument('--vnet_hidden', type=int, default=256) + parser.add_argument('--vnet_layers', type=int, default=3) + # Key new options + parser.add_argument('--term_grad_weight', type=float, default=1.0, + help='Weight for terminal gradient matching loss') + parser.add_argument('--fm_weight', type=float, default=0.0, + help='Weight for FM gradient smoothness auxiliary') + args = parser.parse_args() + run_experiment(args) + + +if __name__ == '__main__': + main() diff --git a/methods/__init__.py b/methods/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/methods/__init__.py diff --git a/metrics/__init__.py b/metrics/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/metrics/__init__.py diff --git a/metrics/__pycache__/__init__.cpython-313.pyc b/metrics/__pycache__/__init__.cpython-313.pyc Binary files differ new file mode 100644 index 0000000..0595726 --- /dev/null +++ b/metrics/__pycache__/__init__.cpython-313.pyc diff 
--git a/metrics/__pycache__/credit_metrics.cpython-313.pyc b/metrics/__pycache__/credit_metrics.cpython-313.pyc Binary files differ new file mode 100644 index 0000000..ef62388 --- /dev/null +++ b/metrics/__pycache__/credit_metrics.cpython-313.pyc diff --git a/metrics/credit_metrics.py b/metrics/credit_metrics.py new file mode 100644 index 0000000..516dca2 --- /dev/null +++ b/metrics/credit_metrics.py @@ -0,0 +1,156 @@ +""" +Credit assignment diagnostic metrics: +1. Exact costate cosine (for toy LQ) +2. Local perturbation correlation rho_l +3. Nudging test Delta_l^nudge +4. Offline BP cosine Gamma_l +5. Bridge residual R_l +6. Feature drift M_l +""" +import torch +import torch.nn.functional as F +import numpy as np +from scipy.stats import pearsonr + + +def cosine_similarity_batch(a: torch.Tensor, b: torch.Tensor) -> float: + """Flatten all non-batch dims, then return the batch-mean cosine similarity of a and b as a float.""" + a_flat = a.reshape(a.shape[0], -1) + b_flat = b.reshape(b.shape[0], -1) + cos = F.cosine_similarity(a_flat, b_flat, dim=-1) + return cos.mean().item() + + +def perturbation_correlation(h_l, a_l, forward_fn, epsilon=1e-3, M=32): + """ + Compute local perturbation correlation rho_l. 
+ + Args: + h_l: (batch, d) hidden state at layer l + a_l: (batch, d) credit signal at layer l + forward_fn: callable mapping h_l -> per-sample loss of shape (batch,); a scalar return breaks the concatenation below + epsilon: perturbation magnitude + M: number of random directions + + Returns: + rho: Pearson correlation between predicted and true loss changes, pooled over all samples and directions + """ + batch_size, d = h_l.shape + device = h_l.device + + pred_list = [] + true_list = [] + + base_loss = forward_fn(h_l) # (batch,) + + for _ in range(M): + v = torch.randn(batch_size, d, device=device) + v = v / (v.norm(dim=-1, keepdim=True) + 1e-8) + + # Predicted change: <a_l, epsilon * v> + delta_pred = (a_l * (epsilon * v)).sum(dim=-1) # (batch,) + + # True change: forward from perturbed h + perturbed_loss = forward_fn(h_l + epsilon * v) # (batch,) + delta_true = perturbed_loss - base_loss # (batch,) + + pred_list.append(delta_pred.detach().cpu().numpy()) + true_list.append(delta_true.detach().cpu().numpy()) + + pred_arr = np.concatenate(pred_list) + true_arr = np.concatenate(true_list) + + if np.std(pred_arr) < 1e-12 or np.std(true_arr) < 1e-12: + return 0.0 + + rho, _ = pearsonr(pred_arr, true_arr) + return float(rho) + + +def nudging_test(h_l, a_l, forward_fn, eta=0.01): + """ + Nudging test: check if moving h_l in -a_l direction decreases loss. + + Args: + h_l: (batch, d) hidden state + a_l: (batch, d) credit signal + forward_fn: callable h -> loss per sample (batch,) + eta: step size + + Returns: + mean delta_nudge (negative is good) + """ + rms_a = (a_l ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_normed = a_l / rms_a + h_nudged = h_l - eta * a_normed + + base_loss = forward_fn(h_l) + nudged_loss = forward_fn(h_nudged) + delta = (nudged_loss - base_loss).mean().item() + return delta + + +def offline_bp_cosine(a_l, bp_grad_l): + """ + Compute offline BP cosine similarity. 
+ a_l: (batch, d) credit signal + bp_grad_l: (batch, d) true BP gradient at layer l + """ + return cosine_similarity_batch(a_l, bp_grad_l) + + +def bridge_residual(V_phi, V_bar_phi, h_l, t_l, s, h_l_next_noisy_list, t_l_next, lam=0.1): + """ + Compute bridge residual R_l. + + Args: + V_phi: value network + V_bar_phi: EMA target value network + h_l: (batch, d) + t_l: (batch,) + s: (batch, s_dim) + h_l_next_noisy_list: list of K tensors (batch, d), noisy next states + t_l_next: (batch,) + lam: temperature + + Returns: + mean absolute bridge residual + """ + with torch.no_grad(): + V_current = V_phi(h_l, t_l, s) # (batch,) + + # Compute soft-min target + K = len(h_l_next_noisy_list) + log_terms = [] + for h_next in h_l_next_noisy_list: + V_next = V_bar_phi(h_next, t_l_next, s) # (batch,) + log_terms.append(-V_next / lam) + + log_terms = torch.stack(log_terms, dim=-1) # (batch, K) + V_target = -lam * torch.logsumexp(log_terms, dim=-1) + lam * np.log(K) + + residual = (V_current - V_target).abs().mean().item() + return residual + + +def feature_drift(model_init_params, model_final_params): + """ + Compute per-layer feature drift M_l = ||W_final - W_init||_F / ||W_init||_F. 
+ + Args: + model_init_params: dict of {name: tensor} initial parameters + model_final_params: dict of {name: tensor} final parameters + + Returns: + dict of {name: drift_ratio} + """ + drifts = {} + for name in model_init_params: + if name in model_final_params: + w_init = model_init_params[name] + w_final = model_final_params[name] + init_norm = w_init.norm().item() + if init_norm > 1e-8: + drift = (w_final - w_init).norm().item() / init_norm + drifts[name] = drift + return drifts diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/models/__init__.py diff --git a/models/__pycache__/__init__.cpython-313.pyc b/models/__pycache__/__init__.cpython-313.pyc Binary files differ new file mode 100644 index 0000000..cb3f264 --- /dev/null +++ b/models/__pycache__/__init__.cpython-313.pyc diff --git a/models/__pycache__/residual_mlp.cpython-313.pyc b/models/__pycache__/residual_mlp.cpython-313.pyc Binary files differ new file mode 100644 index 0000000..c758f50 --- /dev/null +++ b/models/__pycache__/residual_mlp.cpython-313.pyc diff --git a/models/__pycache__/state_bridge.cpython-313.pyc b/models/__pycache__/state_bridge.cpython-313.pyc Binary files differ new file mode 100644 index 0000000..69e1071 --- /dev/null +++ b/models/__pycache__/state_bridge.cpython-313.pyc diff --git a/models/__pycache__/value_net.cpython-313.pyc b/models/__pycache__/value_net.cpython-313.pyc Binary files differ new file mode 100644 index 0000000..a6187ee --- /dev/null +++ b/models/__pycache__/value_net.cpython-313.pyc diff --git a/models/residual_mlp.py b/models/residual_mlp.py new file mode 100644 index 0000000..c16778c --- /dev/null +++ b/models/residual_mlp.py @@ -0,0 +1,73 @@ +""" +Deep Residual MLP for classification. +Architecture: Input -> Linear embedding -> L residual blocks -> LayerNorm -> Linear output head. 
+Each block: h_{l+1} = h_l + W2 * GELU(W1 * LN(h_l)) +""" +import torch +import torch.nn as nn + + +class ResidualBlock(nn.Module): + """Single pre-LayerNorm residual MLP block.""" + + def __init__(self, d_hidden: int): + super().__init__() + self.ln = nn.LayerNorm(d_hidden) + self.w1 = nn.Linear(d_hidden, d_hidden) + self.w2 = nn.Linear(d_hidden, d_hidden) + # Small init for residual branch + nn.init.normal_(self.w2.weight, std=0.01) + nn.init.zeros_(self.w2.bias) + + def forward(self, h: torch.Tensor) -> torch.Tensor: + """Returns the residual F_l(h), NOT h + F_l(h).""" + z = self.ln(h) + z = self.w1(z) + z = torch.nn.functional.gelu(z) + z = self.w2(z) + return z + + +class ResidualMLP(nn.Module): + """Deep residual MLP: embed -> L blocks -> LN -> output head.""" + + def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int): + super().__init__() + self.embed = nn.Linear(input_dim, d_hidden) + self.blocks = nn.ModuleList([ResidualBlock(d_hidden) for _ in range(num_blocks)]) + self.out_ln = nn.LayerNorm(d_hidden) + self.out_head = nn.Linear(d_hidden, num_classes) + self.num_blocks = num_blocks + self.d_hidden = d_hidden + + def forward(self, x: torch.Tensor, return_hidden: bool = False): + """ + Args: + x: (batch, input_dim) + return_hidden: if True, also return list of hidden states [h_0, ..., h_L] + Returns: + logits: (batch, num_classes) + hiddens: list of (batch, d_hidden) if return_hidden + """ + h = self.embed(x) + hiddens = [h] if return_hidden else None + + for block in self.blocks: + f = block(h) + h = h + f + if return_hidden: + hiddens.append(h) + + logits = self.out_head(self.out_ln(h)) + + if return_hidden: + return logits, hiddens + return logits + + def forward_from_layer(self, h: torch.Tensor, start_layer: int): + """Run forward from a given layer index to output. 
Used for perturbation tests.""" + for i in range(start_layer, self.num_blocks): + f = self.blocks[i](h) + h = h + f + logits = self.out_head(self.out_ln(h)) + return logits diff --git a/models/state_bridge.py b/models/state_bridge.py new file mode 100644 index 0000000..0a0e7aa --- /dev/null +++ b/models/state_bridge.py @@ -0,0 +1,35 @@ +""" +State Bridge predictor G_psi(h_l, t_l, s) -> predicted h_L. +Used by the State Bridge method. +""" +import torch +import torch.nn as nn +from .value_net import SinusoidalTimeEmbed + + +class StateBridgeNet(nn.Module): + """ + State predictor G_psi(h_l, t_l, s) -> predicted terminal state h_L. + """ + + def __init__(self, d_hidden: int, s_dim: int, time_embed_dim: int = 32, + hidden_dim: int = 256, num_layers: int = 3): + super().__init__() + self.ln = nn.LayerNorm(d_hidden) + self.time_embed = SinusoidalTimeEmbed(time_embed_dim) + + input_dim = d_hidden + time_embed_dim + s_dim + layers = [] + for i in range(num_layers): + in_d = input_dim if i == 0 else hidden_dim + layers.append(nn.Linear(in_d, hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, d_hidden)) + self.net = nn.Sequential(*layers) + + def forward(self, h: torch.Tensor, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor: + """Returns predicted h_L as (batch, d_hidden).""" + h_normed = self.ln(h) + t_emb = self.time_embed(t) + inp = torch.cat([h_normed, t_emb, s], dim=-1) + return self.net(inp) diff --git a/models/value_net.py b/models/value_net.py new file mode 100644 index 0000000..3c72f75 --- /dev/null +++ b/models/value_net.py @@ -0,0 +1,77 @@ +""" +Value network V_phi(h_l, t_l, s) -> scalar. +Used by the Credit Bridge method. +Input: [LN(h_l), time_embed(t_l), s] concatenated. 
+""" +import torch +import torch.nn as nn +import math +import copy + + +class SinusoidalTimeEmbed(nn.Module): + """Sinusoidal positional encoding for scalar depth-time t_l = l/L.""" + + def __init__(self, embed_dim: int): + super().__init__() + self.embed_dim = embed_dim + + def forward(self, t: torch.Tensor) -> torch.Tensor: + """t: (batch,) or (batch, 1) scalar in [0,1].""" + if t.dim() == 1: + t = t.unsqueeze(-1) # (batch, 1) + half = self.embed_dim // 2 + freqs = torch.exp( + -math.log(10000.0) * torch.arange(half, device=t.device, dtype=t.dtype) / half + ) + args = t * freqs.unsqueeze(0) # (batch, half) + return torch.cat([torch.sin(args), torch.cos(args)], dim=-1) # (batch, embed_dim) + + +class ValueNet(nn.Module): + """ + Scalar value network V_phi(h_l, t_l, s). + Inputs: + h: hidden state (batch, d_hidden) + t: depth-time scalar (batch,) in [0, 1] + s: terminal modulation code (batch, s_dim) + Output: + V: scalar (batch,) + """ + + def __init__(self, d_hidden: int, s_dim: int, time_embed_dim: int = 32, + hidden_dim: int = 256, num_layers: int = 3): + super().__init__() + self.ln = nn.LayerNorm(d_hidden) + self.time_embed = SinusoidalTimeEmbed(time_embed_dim) + + input_dim = d_hidden + time_embed_dim + s_dim + layers = [] + for i in range(num_layers): + in_d = input_dim if i == 0 else hidden_dim + layers.append(nn.Linear(in_d, hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, 1)) + self.net = nn.Sequential(*layers) + + def forward(self, h: torch.Tensor, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor: + """Returns V(h, t, s) as (batch,) scalar.""" + h_normed = self.ln(h) + t_emb = self.time_embed(t) + inp = torch.cat([h_normed, t_emb, s], dim=-1) + return self.net(inp).squeeze(-1) + + +def create_ema_model(model: nn.Module) -> nn.Module: + """Create an EMA copy of a model.""" + ema = copy.deepcopy(model) + for p in ema.parameters(): + p.requires_grad_(False) + return ema + + +@torch.no_grad() +def update_ema(model: nn.Module, 
ema_model: nn.Module, momentum: float = 0.99): + """Update EMA model parameters.""" + for p, ep in zip(model.parameters(), ema_model.parameters()): + ep.data.mul_(momentum).add_(p.data, alpha=1 - momentum) diff --git a/report/toy_bridge_residual.png b/report/toy_bridge_residual.png Binary files differ new file mode 100644 index 0000000..03eeb47 --- /dev/null +++ b/report/toy_bridge_residual.png diff --git a/report/toy_per_layer_diagnostics.png b/report/toy_per_layer_diagnostics.png Binary files differ new file mode 100644 index 0000000..d31b188 --- /dev/null +++ b/report/toy_per_layer_diagnostics.png diff --git a/report/toy_term_grad_effect.png b/report/toy_term_grad_effect.png Binary files differ new file mode 100644 index 0000000..13f0458 --- /dev/null +++ b/report/toy_term_grad_effect.png diff --git a/report/toy_training_curves.png b/report/toy_training_curves.png Binary files differ new file mode 100644 index 0000000..cc3532b --- /dev/null +++ b/report/toy_training_curves.png diff --git a/results/smoke_test/results_fashionmnist.json b/results/smoke_test/results_fashionmnist.json new file mode 100644 index 0000000..8fd82c0 --- /dev/null +++ b/results/smoke_test/results_fashionmnist.json @@ -0,0 +1,511 @@ +{ + "42": { + "bp": { + "log": { + "train_loss": [ + 0.7028828698158264, + 0.5331447437604269, + 0.4640675885995229, + 0.416527880080541, + 0.3784152720451355 + ], + "train_acc": [ + 0.73755, + 0.8002166666666667, + 0.8256666666666667, + 0.8436666666666667, + 0.8591833333333333 + ], + "test_acc": [ + 0.7939, + 0.8157, + 0.8379, + 0.8606, + 0.8658 + ] + }, + "diagnostics": { + "bp_cosine": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0 + ], + "perturbation_rho": [ + 0.9998708367347717, + 0.9998407959938049, + 0.9997469186782837, + 0.9998122453689575, + 0.9997262358665466, + 0.9996491074562073, + 0.9996585845947266, + 0.9995328783988953 + ], + "nudging": { + "0.001": [ + -0.0019641336984932423, + -0.00174636859446764, + -0.00154352025128901, + 
-0.0013810225063934922, + -0.0012467722408473492, + -0.0011190228397026658, + -0.0010304137831553817, + -0.0009783159475773573 + ], + "0.003": [ + -0.005858615506440401, + -0.005213461350649595, + -0.004612031392753124, + -0.004128905013203621, + -0.0037290011532604694, + -0.003348270896822214, + -0.003083921270444989, + -0.002928499598056078 + ], + "0.01": [ + -0.019138235598802567, + -0.017080917954444885, + -0.015158241614699364, + -0.013598069548606873, + -0.01229821052402258, + -0.01105712354183197, + -0.010194781236350536, + -0.009686501696705818 + ] + } + }, + "drift": { + "embed.weight": 1.1395236897587921, + "embed.bias": 0.6760135861021652, + "blocks.0.ln.weight": 0.0347786583006382, + "blocks.0.w1.weight": 0.8267150385727446, + "blocks.0.w1.bias": 0.7665003752239994, + "blocks.0.w2.weight": 2.522163826858937, + "blocks.1.ln.weight": 0.0373835414648056, + "blocks.1.w1.weight": 0.8094006579319112, + "blocks.1.w1.bias": 0.7074648417912711, + "blocks.1.w2.weight": 2.427488314293417, + "blocks.2.ln.weight": 0.0362338162958622, + "blocks.2.w1.weight": 0.7981608599321399, + "blocks.2.w1.bias": 0.6653993627306621, + "blocks.2.w2.weight": 2.3264717248101827, + "blocks.3.ln.weight": 0.03489774093031883, + "blocks.3.w1.weight": 0.8024856088475522, + "blocks.3.w1.bias": 0.6243261101573547, + "blocks.3.w2.weight": 2.2732641732905714, + "blocks.4.ln.weight": 0.036862026900053024, + "blocks.4.w1.weight": 0.7702369058900084, + "blocks.4.w1.bias": 0.671902653164234, + "blocks.4.w2.weight": 2.11229934068397, + "blocks.5.ln.weight": 0.04049132019281387, + "blocks.5.w1.weight": 0.74033776504041, + "blocks.5.w1.bias": 0.6447910547850285, + "blocks.5.w2.weight": 1.9146335569252138, + "blocks.6.ln.weight": 0.03797098994255066, + "blocks.6.w1.weight": 0.7241522377185389, + "blocks.6.w1.bias": 0.6486550903936706, + "blocks.6.w2.weight": 1.8210870137239685, + "blocks.7.ln.weight": 0.03962903097271919, + "blocks.7.w1.weight": 0.6992516513120532, + "blocks.7.w1.bias": 
0.7021825186584477, + "blocks.7.w2.weight": 1.7902862835380957, + "out_ln.weight": 0.026629405096173286, + "out_head.weight": 0.5610428179907003, + "out_head.bias": 0.24151687322978704 + } + }, + "dfa": { + "log": { + "train_loss": [ + 1.4100907169977823, + 1.4334057479222615, + 1.4326289967854817, + 1.4254953683853149, + 1.4169784986495972 + ], + "train_acc": [ + 0.4805, + 0.49806666666666666, + 0.49725, + 0.5031166666666667, + 0.5098 + ], + "test_acc": [ + 0.5175, + 0.5282, + 0.5025, + 0.5027, + 0.5338 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.2141590267419815, + 0.0797841027379036, + 0.04707183688879013, + 0.0032249214127659798, + 0.024886084720492363, + -0.0050605954602360725, + 0.009758025407791138, + 0.020020857453346252 + ], + "perturbation_rho": [ + 0.07338915765285492, + -0.004745986312627792, + 0.018682830035686493, + 0.031185101717710495, + -0.02063235454261303, + 0.0006608979310840368, + 0.0, + 0.04759033024311066 + ], + "nudging": { + "0.001": [ + -3.5976991057395935e-06, + 2.3283064365386963e-09, + -9.313225746154785e-10, + -9.313225746154785e-09, + 4.656612873077393e-10, + -3.725290298461914e-09, + 1.3969838619232178e-09, + 9.313225746154785e-10 + ], + "0.003": [ + -1.0745832696557045e-05, + 2.3283064365386963e-09, + 1.6298145055770874e-09, + -2.3283064365386963e-09, + 6.752088665962219e-09, + 1.0244548320770264e-08, + -4.6566128730773926e-09, + 0.0 + ], + "0.01": [ + -3.5760458558797836e-05, + -9.313225746154785e-10, + 1.3271346688270569e-08, + -1.0244548320770264e-08, + -8.381903171539307e-09, + 4.1443854570388794e-08, + 8.847564458847046e-09, + 1.30385160446167e-08 + ] + } + }, + "drift": { + "embed.weight": 34.15137184586086, + "embed.bias": 25.886722942466992, + "blocks.0.ln.weight": 2.2413852214813232, + "blocks.0.w1.weight": 42.91370684219911, + "blocks.0.w1.bias": 42.38937429728957, + "blocks.0.w2.weight": 115.11173260217275, + "blocks.1.ln.weight": 2.0233230590820312, + "blocks.1.w1.weight": 36.64772802731374, + "blocks.1.w1.bias": 
34.57367344043412, + "blocks.1.w2.weight": 84.1198032305499, + "blocks.2.ln.weight": 1.9000216722488403, + "blocks.2.w1.weight": 33.172328512058, + "blocks.2.w1.bias": 31.908254113444393, + "blocks.2.w2.weight": 78.32752075434591, + "blocks.3.ln.weight": 1.9099335670471191, + "blocks.3.w1.weight": 36.73019631908303, + "blocks.3.w1.bias": 32.60666919280332, + "blocks.3.w2.weight": 83.75068434979308, + "blocks.4.ln.weight": 1.891120195388794, + "blocks.4.w1.weight": 35.27032832987592, + "blocks.4.w1.bias": 38.017692746712825, + "blocks.4.w2.weight": 80.26417466790754, + "blocks.5.ln.weight": 2.0106024742126465, + "blocks.5.w1.weight": 42.09808335703852, + "blocks.5.w1.bias": 43.15100280108635, + "blocks.5.w2.weight": 94.64753309078039, + "blocks.6.ln.weight": 1.8941009044647217, + "blocks.6.w1.weight": 38.94163125135345, + "blocks.6.w1.bias": 38.10426587138794, + "blocks.6.w2.weight": 84.40488806201412, + "blocks.7.ln.weight": 1.9035111665725708, + "blocks.7.w1.weight": 38.65051912560395, + "blocks.7.w1.bias": 40.760402959190415, + "blocks.7.w2.weight": 81.97372863530312, + "out_ln.weight": 0.39740806818008423, + "out_head.weight": 3.609484615833081, + "out_head.bias": 0.8344298895862311 + } + }, + "state_bridge": { + "log": { + "train_loss": [ + 1.7336509002049765, + 1.5851847206751506, + 1.8742321704864502, + 1.8100628153483074, + 1.5580067304611207 + ], + "train_acc": [ + 0.32705, + 0.3616333333333333, + 0.2679166666666667, + 0.31853333333333333, + 0.4080166666666667 + ], + "test_acc": [ + 0.4036, + 0.4047, + 0.3046, + 0.4005, + 0.4651 + ], + "state_pred_error": [ + 5076234.573784879, + 343992182.36586666, + 517900685.38026667, + 557639895.7226666, + 383439664.5205333 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.35149621963500977, + 0.27368226647377014, + 0.045176729559898376, + 0.04587914049625397, + 0.06403794139623642, + 0.06599076837301254, + 0.10026843845844269, + 0.11267217993736267 + ], + "perturbation_rho": [ + 0.37996596097946167, + 
0.0075237625278532505, + -0.017497196793556213, + 0.001783197745680809, + 0.026772135868668556, + 0.011043311096727848, + 0.0037202914245426655, + 0.008577261120080948 + ], + "nudging": { + "0.001": [ + -7.0138368755579e-05, + -1.0319054126739502e-06, + -6.495974957942963e-08, + -4.190951585769653e-09, + -3.3993273973464966e-08, + -2.2584572434425354e-08, + -6.123445928096771e-08, + -4.1443854570388794e-08 + ], + "0.003": [ + -0.00021021789871156216, + -3.0745286494493484e-06, + -1.0547228157520294e-07, + -4.866160452365875e-08, + -8.591450750827789e-08, + -6.938353180885315e-08, + -1.1292286217212677e-07, + -9.825453162193298e-08 + ], + "0.01": [ + -0.0006983885541558266, + -1.027202233672142e-05, + -3.688037395477295e-07, + -1.073349267244339e-07, + -2.153683453798294e-07, + -2.0815059542655945e-07, + -2.6938505470752716e-07, + -3.08966264128685e-07 + ] + } + }, + "drift": { + "embed.weight": 4.244627827603776, + "embed.bias": 2.284214237417933, + "blocks.0.ln.weight": 0.7376585006713867, + "blocks.0.w1.weight": 10.052587254542228, + "blocks.0.w1.bias": 12.863219234970343, + "blocks.0.w2.weight": 32.7856072339516, + "blocks.1.ln.weight": 0.8459606170654297, + "blocks.1.w1.weight": 17.10440998330836, + "blocks.1.w1.bias": 22.587196662115986, + "blocks.1.w2.weight": 48.94107595877311, + "blocks.2.ln.weight": 0.5776991248130798, + "blocks.2.w1.weight": 12.17296546064332, + "blocks.2.w1.bias": 14.464613959802604, + "blocks.2.w2.weight": 33.58238884289355, + "blocks.3.ln.weight": 0.6916943788528442, + "blocks.3.w1.weight": 11.021146543598332, + "blocks.3.w1.bias": 11.779720628235973, + "blocks.3.w2.weight": 25.18170826322489, + "blocks.4.ln.weight": 0.5363020300865173, + "blocks.4.w1.weight": 8.488676390390957, + "blocks.4.w1.bias": 11.968348972417077, + "blocks.4.w2.weight": 23.562556821259157, + "blocks.5.ln.weight": 0.749293863773346, + "blocks.5.w1.weight": 13.199470836216618, + "blocks.5.w1.bias": 17.384581140704626, + "blocks.5.w2.weight": 36.38642209120496, + 
"blocks.6.ln.weight": 0.45835214853286743, + "blocks.6.w1.weight": 10.662863447362852, + "blocks.6.w1.bias": 15.855775302559838, + "blocks.6.w2.weight": 28.70293410646866, + "blocks.7.ln.weight": 0.4122738838195801, + "blocks.7.w1.weight": 7.193545064718527, + "blocks.7.w1.bias": 7.340520731384394, + "blocks.7.w2.weight": 22.53669113456487, + "out_ln.weight": 0.06885236501693726, + "out_head.weight": 1.473306835980063, + "out_head.bias": 1.5084479134035678 + } + }, + "credit_bridge": { + "log": { + "train_loss": [ + 2.0466531958262126, + 2.2737758037567137, + 2.280587441889445, + 2.25820095837911, + 2.270971960576375 + ], + "train_acc": [ + 0.22976666666666667, + 0.13306666666666667, + 0.13921666666666666, + 0.15853333333333333, + 0.15521666666666667 + ], + "test_acc": [ + 0.1866, + 0.101, + 0.1455, + 0.1616, + 0.0958 + ], + "value_loss": [ + 1.105676996310552, + 0.06329173700014751, + 0.02014747195293506, + 0.01469622576336066, + 0.0050645585257560015 + ] + }, + "diagnostics": { + "bp_cosine": [ + -0.04117349535226822, + 0.006370606832206249, + 0.03125208616256714, + -0.015287259593605995, + -0.04197325184941292, + -0.021368175745010376, + 0.009605868719518185, + 0.029422588646411896 + ], + "perturbation_rho": [ + -0.09283949434757233, + -0.030718784779310226, + 0.00206748116761446, + 0.0, + -0.0060626850463449955, + -0.015737878158688545, + 0.02677903138101101, + -0.040315717458724976 + ], + "nudging": { + "0.001": [ + 3.939494490623474e-06, + -1.4901161193847656e-08, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "0.003": [ + 1.1788681149482727e-05, + -3.725290298461914e-09, + 1.862645149230957e-09, + 0.0, + 0.0, + -1.862645149230957e-09, + 1.862645149230957e-09, + 0.0 + ], + "0.01": [ + 3.923662006855011e-05, + 2.60770320892334e-08, + 7.450580596923828e-09, + -1.862645149230957e-09, + 1.862645149230957e-09, + 1.862645149230957e-09, + 1.862645149230957e-09, + 5.587935447692871e-09 + ] + } + }, + "drift": { + "embed.weight": 5.5176286267072285, + "embed.bias": 
5.84115130180781, + "blocks.0.ln.weight": 0.7516779899597168, + "blocks.0.w1.weight": 13.130250396475844, + "blocks.0.w1.bias": 16.418874968953887, + "blocks.0.w2.weight": 35.524006089100524, + "blocks.1.ln.weight": 1.0247337818145752, + "blocks.1.w1.weight": 24.629494668532566, + "blocks.1.w1.bias": 31.389502706524333, + "blocks.1.w2.weight": 66.32116257377548, + "blocks.2.ln.weight": 1.0405563116073608, + "blocks.2.w1.weight": 19.820545045709153, + "blocks.2.w1.bias": 21.04420170821489, + "blocks.2.w2.weight": 47.599865577930004, + "blocks.3.ln.weight": 0.698647677898407, + "blocks.3.w1.weight": 11.782448138509505, + "blocks.3.w1.bias": 11.300932589985935, + "blocks.3.w2.weight": 29.52897326429494, + "blocks.4.ln.weight": 0.7269545197486877, + "blocks.4.w1.weight": 12.151248949786332, + "blocks.4.w1.bias": 11.545512427875654, + "blocks.4.w2.weight": 31.11028303079019, + "blocks.5.ln.weight": 0.7007301449775696, + "blocks.5.w1.weight": 10.093534441471926, + "blocks.5.w1.bias": 8.779689334059729, + "blocks.5.w2.weight": 24.2612493038314, + "blocks.6.ln.weight": 0.7457646727561951, + "blocks.6.w1.weight": 9.34015610077582, + "blocks.6.w1.bias": 7.819986138941612, + "blocks.6.w2.weight": 24.326245888357803, + "blocks.7.ln.weight": 0.7317199110984802, + "blocks.7.w1.weight": 10.73492434511088, + "blocks.7.w1.bias": 9.150645079981764, + "blocks.7.w2.weight": 27.11222424242763, + "out_ln.weight": 0.07713422179222107, + "out_head.weight": 1.4131541672079744, + "out_head.bias": 0.7976411656775528 + } + } + }, + "config": { + "dataset": "fashionmnist", + "d_hidden": 256, + "num_blocks": 8, + "batch_size": 128, + "epochs": 5, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 42 + ], + "gpu": 0, + "output_dir": "results/smoke_test", + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/smoke_test2/results_fashionmnist.json b/results/smoke_test2/results_fashionmnist.json new file mode 100644 index 0000000..e4b5d70 --- /dev/null +++ b/results/smoke_test2/results_fashionmnist.json @@ -0,0 +1,721 @@ +{ + "42": { + "bp": { + "log": { + "train_loss": [ + 0.7028828698158264, + 0.5380959739049276, + 0.4825267461776733, + 0.45401095023155214, + 0.4324853138923645, + 0.41239778510729475, + 0.39849929070472717, + 0.38395428166389467, + 0.3671353324095408, + 0.3565144721984863, + 0.3396712650140127, + 0.32847159377733864, + 0.3156646738688151, + 0.3063636976877848, + 0.29311317729949954, + 0.281319753352801, + 0.27345897892316184, + 0.26703561499913536, + 0.26238394471009574, + 0.2581080759366353 + ], + "train_acc": [ + 0.73755, + 0.7990666666666667, + 0.8184166666666667, + 0.8286, + 0.8373333333333334, + 0.8439833333333333, + 0.8508166666666667, + 0.85665, + 0.8617666666666667, + 0.86675, + 0.87145, + 0.8759333333333333, + 0.87975, + 0.8850666666666667, + 0.8894166666666666, + 0.8943333333333333, + 0.8967166666666667, + 0.8991833333333333, + 0.90125, + 0.9020333333333334 + ], + "test_acc": [ + 0.7939, + 0.8145, + 0.8348, + 0.8457, + 0.8515, + 0.8507, + 0.86, + 0.8616, + 0.8629, + 0.867, + 0.8749, + 0.8758, + 0.8749, + 0.8826, + 0.8871, + 0.8893, + 0.8887, + 0.8914, + 0.8918, + 0.8933 + ] + }, + "diagnostics": { + "bp_cosine": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0 + ], + "perturbation_rho": [ + 0.999779462814331, + 0.9997812509536743, + 0.9997234344482422, + 0.9995989799499512, + 0.9995447993278503, + 0.999480128288269, + 0.999315083026886, + 0.9990702271461487 + ], + "nudging": { + "0.001": [ + -0.0015512802638113499, + -0.0014223953476175666, + -0.0012828018516302109, + -0.0011290921829640865, + -0.0010026510572060943, + -0.0008882835973054171, + -0.0007996053318493068, + -0.0007174870697781444 + ], + "0.003": [ + -0.00463186064735055, + -0.004248819313943386, + -0.003833592403680086, + 
-0.003375905565917492, + -0.0029987930320203304, + -0.0026576737873256207, + -0.002393170492723584, + -0.0021477085538208485 + ], + "0.01": [ + -0.015182608738541603, + -0.013949813321232796, + -0.012605881318449974, + -0.011121007613837719, + -0.009890388697385788, + -0.008775115013122559, + -0.007910609245300293, + -0.0071043153293430805 + ] + } + }, + "drift": { + "embed.weight": 2.2156592696361637, + "embed.bias": 1.1767430604557616, + "blocks.0.ln.weight": 0.10304339230060577, + "blocks.0.w1.weight": 1.482853423570453, + "blocks.0.w1.bias": 1.6160260440173806, + "blocks.0.w2.weight": 4.957564419601416, + "blocks.1.ln.weight": 0.10342221707105637, + "blocks.1.w1.weight": 1.4925957913309882, + "blocks.1.w1.bias": 1.387226988584735, + "blocks.1.w2.weight": 4.8677799360581435, + "blocks.2.ln.weight": 0.10164804011583328, + "blocks.2.w1.weight": 1.5187247212963315, + "blocks.2.w1.bias": 1.315003727574603, + "blocks.2.w2.weight": 4.892590160655272, + "blocks.3.ln.weight": 0.0980648621916771, + "blocks.3.w1.weight": 1.5309313086145322, + "blocks.3.w1.bias": 1.183681086835403, + "blocks.3.w2.weight": 4.853244692493399, + "blocks.4.ln.weight": 0.10139396041631699, + "blocks.4.w1.weight": 1.4821627671323643, + "blocks.4.w1.bias": 1.3175104220848128, + "blocks.4.w2.weight": 4.605731857817927, + "blocks.5.ln.weight": 0.11211488395929337, + "blocks.5.w1.weight": 1.4651151472668016, + "blocks.5.w1.bias": 1.2329307722821297, + "blocks.5.w2.weight": 4.4285985689227365, + "blocks.6.ln.weight": 0.10427453368902206, + "blocks.6.w1.weight": 1.4545050310842818, + "blocks.6.w1.bias": 1.2120760166211149, + "blocks.6.w2.weight": 4.21210005778474, + "blocks.7.ln.weight": 0.10458954423666, + "blocks.7.w1.weight": 1.4114832006479092, + "blocks.7.w1.bias": 1.3654059522290227, + "blocks.7.w2.weight": 4.118203814443419, + "out_ln.weight": 0.07383835315704346, + "out_head.weight": 0.9730150380688232, + "out_head.bias": 0.5616520351930565 + } + }, + "dfa": { + "log": { + "train_loss": [ + 
1.4100907169977823, + 1.436597176615397, + 1.4409521081288656, + 1.4391240961710612, + 1.431752833366394, + 1.4278661415100098, + 1.42600347849528, + 1.4258823367436726, + 1.4264747858683269, + 1.4247130299250286, + 1.4233300720850626, + 1.4235623852411905, + 1.423257282193502, + 1.4203537824630736, + 1.4199624996821085, + 1.4181565198898316, + 1.4212706031799316, + 1.4193874025344848, + 1.4183136430740357, + 1.4177714852015177 + ], + "train_acc": [ + 0.4805, + 0.4955833333333333, + 0.48483333333333334, + 0.48445, + 0.48668333333333336, + 0.48556666666666665, + 0.48233333333333334, + 0.48796666666666666, + 0.48468333333333335, + 0.48793333333333333, + 0.49155, + 0.49241666666666667, + 0.49341666666666667, + 0.4975, + 0.49946666666666667, + 0.5013666666666666, + 0.5008166666666667, + 0.5015833333333334, + 0.5064833333333333, + 0.5055166666666666 + ], + "test_acc": [ + 0.5175, + 0.5227, + 0.4511, + 0.5054, + 0.5238, + 0.4415, + 0.4969, + 0.5415, + 0.5148, + 0.463, + 0.5415, + 0.479, + 0.5198, + 0.5259, + 0.5167, + 0.5293, + 0.5023, + 0.513, + 0.5234, + 0.5205 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.18807855248451233, + 0.003098251298069954, + -0.0009170115226879716, + -0.0030592959374189377, + 0.001578816445544362, + -0.002526880707591772, + -0.0002142013981938362, + 0.002638747449964285 + ], + "perturbation_rho": [ + -0.02331582084298134, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.008287552744150162 + ], + "nudging": { + "0.001": [ + -1.016305759549141e-06, + 0.0, + 0.0, + 9.313225746154785e-10, + 0.0, + -1.862645149230957e-09, + 0.0, + 0.0 + ], + "0.003": [ + -3.2687094062566757e-06, + 0.0, + 0.0, + -1.3969838619232178e-09, + -9.313225746154785e-10, + -2.3283064365386963e-09, + -6.984919309616089e-10, + 0.0 + ], + "0.01": [ + -1.1139316484332085e-05, + 9.313225746154785e-10, + 9.313225746154785e-10, + 6.05359673500061e-09, + -9.313225746154785e-10, + -1.3969838619232178e-09, + 1.1641532182693481e-09, + -9.313225746154785e-10 + ] + } + }, + "drift": { + 
"embed.weight": 118.42757824386928, + "embed.bias": 91.53070095956268, + "blocks.0.ln.weight": 6.256274700164795, + "blocks.0.w1.weight": 139.8976340759611, + "blocks.0.w1.bias": 140.08208587154786, + "blocks.0.w2.weight": 332.94308388138757, + "blocks.1.ln.weight": 5.70231819152832, + "blocks.1.w1.weight": 129.75151109455834, + "blocks.1.w1.bias": 121.17432552115143, + "blocks.1.w2.weight": 248.7841682293662, + "blocks.2.ln.weight": 5.466833114624023, + "blocks.2.w1.weight": 124.61655550824435, + "blocks.2.w1.bias": 116.83509018118278, + "blocks.2.w2.weight": 237.4413053793645, + "blocks.3.ln.weight": 5.473834037780762, + "blocks.3.w1.weight": 130.23769942327985, + "blocks.3.w1.bias": 111.35214850610642, + "blocks.3.w2.weight": 253.67927391219192, + "blocks.4.ln.weight": 5.401853084564209, + "blocks.4.w1.weight": 138.8132811262774, + "blocks.4.w1.bias": 138.7305587293139, + "blocks.4.w2.weight": 257.0171443869646, + "blocks.5.ln.weight": 5.765963554382324, + "blocks.5.w1.weight": 155.34176828036547, + "blocks.5.w1.bias": 142.51555937014615, + "blocks.5.w2.weight": 294.4127013869835, + "blocks.6.ln.weight": 5.595874309539795, + "blocks.6.w1.weight": 145.1627005952125, + "blocks.6.w1.bias": 132.11466551434717, + "blocks.6.w2.weight": 265.955180270204, + "blocks.7.ln.weight": 5.609050750732422, + "blocks.7.w1.weight": 147.06679411967872, + "blocks.7.w1.bias": 140.14577658787653, + "blocks.7.w2.weight": 258.1442232900185, + "out_ln.weight": 1.0355137586593628, + "out_head.weight": 8.13612662750539, + "out_head.bias": 0.9719656640023888 + } + }, + "state_bridge": { + "log": { + "train_loss": [ + 1.7336509002049765, + 1.5758645205815633, + 1.9318606907526652, + 2.191984078470866, + 1.8327842120488484, + 1.8950384799957276, + 2.377087445449829, + 2.1417452257792156, + 1.9545116751352947, + 2.1638838397979736, + 2.023908238474528, + 2.336049980545044, + 1.9471053196589152, + 1.7918469515482585, + 1.7594309553146363, + 1.7820972992579143, + 1.7949362015406292, + 
1.7877578330993653, + 1.7872745515823365, + 1.7890100787480672 + ], + "train_acc": [ + 0.32705, + 0.35613333333333336, + 0.25498333333333334, + 0.14148333333333332, + 0.23143333333333332, + 0.23958333333333334, + 0.07806666666666667, + 0.16218333333333335, + 0.20036666666666667, + 0.1637, + 0.23035, + 0.13533333333333333, + 0.21553333333333333, + 0.25556666666666666, + 0.29751666666666665, + 0.32916666666666666, + 0.3303333333333333, + 0.33866666666666667, + 0.33645, + 0.3411 + ], + "test_acc": [ + 0.4036, + 0.3937, + 0.2369, + 0.2137, + 0.201, + 0.0282, + 0.1012, + 0.1994, + 0.1525, + 0.1935, + 0.136, + 0.1863, + 0.2135, + 0.287, + 0.3171, + 0.3199, + 0.3193, + 0.309, + 0.3431, + 0.3284 + ], + "state_pred_error": [ + 5076234.573784879, + 451395873.88586664, + 9768031180.663467, + 22542289907.438934, + 14961824114.824533, + 30110774233.224533, + 11799262674.944, + 64511141264.315735, + 46851419274.717865, + 25131605056.98987, + 35883748757.77707, + 48375831022.7968, + 69333597113.5488, + 40594997328.827736, + 42734315291.71627, + 38351667650.01387, + 32476750280.567467, + 28322754527.232, + 25330875547.648, + 23083374417.237335 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.1429949849843979, + 0.3180725574493408, + 0.20951706171035767, + 0.2099049985408783, + 0.2332497388124466, + 0.18252244591712952, + 0.16670912504196167, + 0.19962334632873535 + ], + "perturbation_rho": [ + 0.22952523827552795, + 0.06709301471710205, + 0.0, + 0.02129420079290867, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "nudging": { + "0.001": [ + -2.1068379282951355e-05, + -8.009374141693115e-08, + 1.862645149230957e-09, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "0.003": [ + -6.297603249549866e-05, + -1.4621764421463013e-07, + 2.7939677238464355e-09, + -7.450580596923828e-09, + 5.587935447692871e-09, + 0.0, + 9.313225746154785e-10, + 0.0 + ], + "0.01": [ + -0.00020940043032169342, + -4.3585896492004395e-07, + 0.0, + -5.587935447692871e-09, + 8.381903171539307e-09, + 0.0, + -4.6566128730773926e-09, + 
0.0 + ] + } + }, + "drift": { + "embed.weight": 10.2616062216752, + "embed.bias": 9.637697583673647, + "blocks.0.ln.weight": 2.2154791355133057, + "blocks.0.w1.weight": 28.129394231034535, + "blocks.0.w1.bias": 29.44761280088024, + "blocks.0.w2.weight": 83.23565733006456, + "blocks.1.ln.weight": 1.6048834323883057, + "blocks.1.w1.weight": 72.41049473988551, + "blocks.1.w1.bias": 83.01353544678275, + "blocks.1.w2.weight": 118.5292924947854, + "blocks.2.ln.weight": 1.528304100036621, + "blocks.2.w1.weight": 22.974203066797124, + "blocks.2.w1.bias": 24.217216068101656, + "blocks.2.w2.weight": 45.236771862572276, + "blocks.3.ln.weight": 2.2928571701049805, + "blocks.3.w1.weight": 48.83060261437247, + "blocks.3.w1.bias": 44.264646679202606, + "blocks.3.w2.weight": 77.90843201854017, + "blocks.4.ln.weight": 1.569786548614502, + "blocks.4.w1.weight": 38.25042383023526, + "blocks.4.w1.bias": 46.7989781903236, + "blocks.4.w2.weight": 61.850484751580424, + "blocks.5.ln.weight": 1.3877884149551392, + "blocks.5.w1.weight": 25.667927717045547, + "blocks.5.w1.bias": 27.47666495212993, + "blocks.5.w2.weight": 56.74081708072433, + "blocks.6.ln.weight": 2.3657233715057373, + "blocks.6.w1.weight": 68.78499791753522, + "blocks.6.w1.bias": 66.7112278827284, + "blocks.6.w2.weight": 88.32659623386007, + "blocks.7.ln.weight": 1.5114411115646362, + "blocks.7.w1.weight": 40.81772220069051, + "blocks.7.w1.bias": 45.46224641191049, + "blocks.7.w2.weight": 78.43904162151014, + "out_ln.weight": 0.27722063660621643, + "out_head.weight": 3.4145406581484017, + "out_head.bias": 1.9963410657002738 + } + }, + "credit_bridge": { + "log": { + "train_loss": [ + 1.4013389415105184, + 1.4349893891016643, + 1.442590291341146, + 1.4306699399312337, + 1.4177788179397584, + 1.4436163345336914, + 1.5327887171427408, + 1.5964955659866333, + 1.6370685063680013, + 1.6712419691085816, + 1.6618358172098795, + 1.6343142255147298, + 1.6132340278625488, + 1.5844669729232788, + 1.5616684883117675, + 1.541010483233134, 
+ 1.525136720085144, + 1.5156665041605633, + 1.5106076243718465, + 1.5088000661214194 + ], + "train_acc": [ + 0.4805333333333333, + 0.48885, + 0.4822, + 0.48246666666666665, + 0.48988333333333334, + 0.4848, + 0.4645166666666667, + 0.44521666666666665, + 0.4347, + 0.4149833333333333, + 0.41485, + 0.42995, + 0.4308, + 0.43275, + 0.44, + 0.4475, + 0.44975, + 0.45098333333333335, + 0.4527333333333333, + 0.45531666666666665 + ], + "test_acc": [ + 0.4889, + 0.4758, + 0.4915, + 0.5234, + 0.5192, + 0.4658, + 0.4303, + 0.4477, + 0.4427, + 0.4192, + 0.4341, + 0.4369, + 0.4415, + 0.4713, + 0.4554, + 0.4785, + 0.4729, + 0.47, + 0.4671, + 0.4741 + ], + "value_loss": [ + 0.4515702169219653, + 0.10145944621364275, + 0.0791723535656929, + 0.06821208957632383, + 0.048990531079967814, + 0.04999728363355001, + 0.0479251907547315, + 0.04561474083264669, + 0.03763539131681124, + 0.040053354968627296, + 0.03465893845955531, + 0.029476907690366108, + 0.022895141302545864, + 0.020180800672868888, + 0.01736196913222472, + 0.012846514955163002, + 0.009126213994373878, + 0.008726185441513856, + 0.008221083876490592, + 0.007363871405273676 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.15763823688030243, + 0.11252982914447784, + 0.23011641204357147, + 0.21486178040504456, + 0.20616328716278076, + 0.2209414690732956, + 0.22547751665115356, + 0.21971024572849274 + ], + "perturbation_rho": [ + 0.022450335323810577, + 0.03968421369791031, + -0.010952746495604515, + -0.02441849745810032, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "nudging": { + "0.001": [ + -2.5480985641479492e-06, + -3.864988684654236e-08, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "0.003": [ + -7.3623377829790115e-06, + -4.377216100692749e-08, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "0.01": [ + -2.4445587769150734e-05, + -8.102506399154663e-08, + -1.862645149230957e-09, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ] + } + }, + "drift": { + "embed.weight": 56.40174417770624, + "embed.bias": 40.65522803248636, + "blocks.0.ln.weight": 
3.5452845096588135, + "blocks.0.w1.weight": 60.29274021947322, + "blocks.0.w1.bias": 39.86138087773039, + "blocks.0.w2.weight": 158.13629039802228, + "blocks.1.ln.weight": 4.1766037940979, + "blocks.1.w1.weight": 97.47012476867003, + "blocks.1.w1.bias": 94.45373800951319, + "blocks.1.w2.weight": 182.69479749878457, + "blocks.2.ln.weight": 3.4605705738067627, + "blocks.2.w1.weight": 73.73554694789455, + "blocks.2.w1.bias": 70.46643111415602, + "blocks.2.w2.weight": 154.24816844331818, + "blocks.3.ln.weight": 3.458436965942383, + "blocks.3.w1.weight": 81.09956449846064, + "blocks.3.w1.bias": 70.54884627062913, + "blocks.3.w2.weight": 164.27884471330486, + "blocks.4.ln.weight": 3.470608949661255, + "blocks.4.w1.weight": 84.05857705229029, + "blocks.4.w1.bias": 85.61940602638978, + "blocks.4.w2.weight": 160.78897537314361, + "blocks.5.ln.weight": 3.7341277599334717, + "blocks.5.w1.weight": 97.50908269356889, + "blocks.5.w1.bias": 91.11903815503673, + "blocks.5.w2.weight": 187.97349614569833, + "blocks.6.ln.weight": 3.5159947872161865, + "blocks.6.w1.weight": 91.22877884822417, + "blocks.6.w1.bias": 83.71561937882119, + "blocks.6.w2.weight": 167.99172531428107, + "blocks.7.ln.weight": 3.4693915843963623, + "blocks.7.w1.weight": 88.21253603229287, + "blocks.7.w1.bias": 85.04756223959555, + "blocks.7.w2.weight": 156.73218429415434, + "out_ln.weight": 0.8287110924720764, + "out_head.weight": 7.02544389908901, + "out_head.bias": 2.6789805309811348 + } + } + }, + "config": { + "dataset": "fashionmnist", + "d_hidden": 256, + "num_blocks": 8, + "batch_size": 128, + "epochs": 20, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 42 + ], + "gpu": 0, + "output_dir": "results/smoke_test2", + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/toy_lq/state_bridge_seed42.pt b/results/toy_lq/state_bridge_seed42.pt Binary files differ new file mode 100644 index 0000000..a87e99d --- /dev/null +++ b/results/toy_lq/state_bridge_seed42.pt diff --git a/results/toy_lq/sweep_results.json b/results/toy_lq/sweep_results.json new file mode 100644 index 0000000..30c3f49 --- /dev/null +++ b/results/toy_lq/sweep_results.json @@ -0,0 +1,1070 @@ +{ + "base": { + "best_cos": 0.28987203165888786, + "best_step": 500, + "final_cos": -0.0006980087685709199, + "final_rho": 0.00831739305673788, + "final_nudge": 0.0027815375166634717, + "history": [ + { + "step": 500, + "avg_cos": 0.28987203165888786, + "avg_rho": 0.30755721777677536, + "avg_nudge": -0.10750458451608817, + "loss_term": 0.4910352826118469, + "loss_bridge": 0.21142145991325378 + }, + { + "step": 1000, + "avg_cos": 0.17731603110829988, + "avg_rho": 0.17737891773382822, + "avg_nudge": -0.06381315520654122, + "loss_term": 0.1414657086133957, + "loss_bridge": 0.16202926635742188 + }, + { + "step": 1500, + "avg_cos": 0.07300025007377069, + "avg_rho": 0.07345063022027414, + "avg_nudge": -0.02641519942941765, + "loss_term": 0.07946588099002838, + "loss_bridge": 0.08628662675619125 + }, + { + "step": 2000, + "avg_cos": 0.05074742032835881, + "avg_rho": 0.048574333622430764, + "avg_nudge": -0.016719758200148743, + "loss_term": 0.06007867306470871, + "loss_bridge": 0.016275471076369286 + }, + { + "step": 2500, + "avg_cos": 0.045629044994711876, + "avg_rho": 0.0567939051737388, + "avg_nudge": -0.016934571011612814, + "loss_term": 0.04579576849937439, + "loss_bridge": 0.008496936410665512 + }, + { + "step": 3000, + "avg_cos": 0.06145744491368532, + "avg_rho": 0.05695042138298353, + "avg_nudge": -0.020737762407710154, + "loss_term": 0.04881729558110237, + "loss_bridge": 0.013217507861554623 + }, + { + "step": 3500, + "avg_cos": 0.04631757119204849, + "avg_rho": 0.0402289762472113, + "avg_nudge": -0.015406294725835323, +
"loss_term": 0.04255056008696556, + "loss_bridge": 0.01808304153382778 + }, + { + "step": 4000, + "avg_cos": 0.07397006265819073, + "avg_rho": 0.07162392682706316, + "avg_nudge": -0.025235373992472887, + "loss_term": 0.05486675351858139, + "loss_bridge": 0.02182137593626976 + }, + { + "step": 4500, + "avg_cos": 0.07251314166933298, + "avg_rho": 0.07301904633641243, + "avg_nudge": -0.027378834628810484, + "loss_term": 0.027670202776789665, + "loss_bridge": 0.013123266398906708 + }, + { + "step": 5000, + "avg_cos": -0.0006980087685709199, + "avg_rho": 0.00831739305673788, + "avg_nudge": 0.0027815375166634717, + "loss_term": 0.021854262799024582, + "loss_bridge": 0.0122066093608737 + } + ] + }, + "noise_0.1": { + "best_cos": 0.28885648772120476, + "best_step": 500, + "final_cos": 0.02330129671220978, + "final_rho": 0.03084004654859503, + "final_nudge": -0.005932313855737448, + "history": [ + { + "step": 500, + "avg_cos": 0.28885648772120476, + "avg_rho": 0.306567445397377, + "avg_nudge": -0.1071426725635926, + "loss_term": 0.4925232529640198, + "loss_bridge": 0.20981840789318085 + }, + { + "step": 1000, + "avg_cos": 0.17456108083327612, + "avg_rho": 0.17607956007122993, + "avg_nudge": -0.06276426805804174, + "loss_term": 0.1324017345905304, + "loss_bridge": 0.1835721731185913 + }, + { + "step": 1500, + "avg_cos": 0.08755840392162402, + "avg_rho": 0.08556619860852759, + "avg_nudge": -0.03154423305143913, + "loss_term": 0.07669232785701752, + "loss_bridge": 0.10240059345960617 + }, + { + "step": 2000, + "avg_cos": 0.05209386876473824, + "avg_rho": 0.04985952338514229, + "avg_nudge": -0.01653688432027896, + "loss_term": 0.06218753010034561, + "loss_bridge": 0.02190624177455902 + }, + { + "step": 2500, + "avg_cos": 0.038306045811623335, + "avg_rho": 0.047700356847296156, + "avg_nudge": -0.01320391776971519, + "loss_term": 0.06128765642642975, + "loss_bridge": 0.01958899199962616 + }, + { + "step": 3000, + "avg_cos": 0.059713449950019516, + "avg_rho": 0.054914615117013454, 
+ "avg_nudge": -0.01939693869402011, + "loss_term": 0.04238733649253845, + "loss_bridge": 0.006958359386771917 + }, + { + "step": 3500, + "avg_cos": 0.04283027160757532, + "avg_rho": 0.028526597811530035, + "avg_nudge": -0.01381517636279265, + "loss_term": 0.024001002311706543, + "loss_bridge": 0.006152431480586529 + }, + { + "step": 4000, + "avg_cos": 0.0656419579560558, + "avg_rho": 0.06320031352030735, + "avg_nudge": -0.02216823499960204, + "loss_term": 0.04236245155334473, + "loss_bridge": 0.01628262549638748 + }, + { + "step": 4500, + "avg_cos": 0.05787398883452018, + "avg_rho": 0.0610881638713181, + "avg_nudge": -0.02213552314788103, + "loss_term": 0.024835661053657532, + "loss_bridge": 0.008629711344838142 + }, + { + "step": 5000, + "avg_cos": 0.02330129671220978, + "avg_rho": 0.03084004654859503, + "avg_nudge": -0.005932313855737448, + "loss_term": 0.035377949476242065, + "loss_bridge": 0.020307490602135658 + } + ] + }, + "noise_0.3": { + "best_cos": 0.28320933257540065, + "best_step": 500, + "final_cos": -0.007130145425132166, + "final_rho": 0.003303757131410142, + "final_nudge": 0.004990910800794761, + "history": [ + { + "step": 500, + "avg_cos": 0.28320933257540065, + "avg_rho": 0.30105741570393246, + "avg_nudge": -0.10515742873152097, + "loss_term": 0.44116002321243286, + "loss_bridge": 0.210750013589859 + }, + { + "step": 1000, + "avg_cos": 0.1680816188454628, + "avg_rho": 0.1721916707853476, + "avg_nudge": -0.060552582144737244, + "loss_term": 0.14207889139652252, + "loss_bridge": 0.16882070899009705 + }, + { + "step": 1500, + "avg_cos": 0.08877109829336405, + "avg_rho": 0.08808179199695587, + "avg_nudge": -0.03221472934819758, + "loss_term": 0.12282441556453705, + "loss_bridge": 0.07336755841970444 + }, + { + "step": 2000, + "avg_cos": 0.04903770827998718, + "avg_rho": 0.04795513402981063, + "avg_nudge": -0.016069856472313404, + "loss_term": 0.06662434339523315, + "loss_bridge": 0.02520526573061943 + }, + { + "step": 2500, + "avg_cos": 
0.05001032492145896, + "avg_rho": 0.053442057532568775, + "avg_nudge": -0.01783018947268526, + "loss_term": 0.04614880681037903, + "loss_bridge": 0.015394347719848156 + }, + { + "step": 3000, + "avg_cos": 0.03979540191357955, + "avg_rho": 0.03453826089389622, + "avg_nudge": -0.012916556559503078, + "loss_term": 0.04155049845576286, + "loss_bridge": 0.007860729470849037 + }, + { + "step": 3500, + "avg_cos": 0.06703585013747215, + "avg_rho": 0.04949762811884284, + "avg_nudge": -0.02297227829694748, + "loss_term": 0.059078700840473175, + "loss_bridge": 0.030750930309295654 + }, + { + "step": 4000, + "avg_cos": 0.06641693723698457, + "avg_rho": 0.06286459214364488, + "avg_nudge": -0.021883966866880655, + "loss_term": 0.025344880297780037, + "loss_bridge": 0.008583566173911095 + }, + { + "step": 4500, + "avg_cos": 0.03687915485352278, + "avg_rho": 0.028895257118468482, + "avg_nudge": -0.01225966370354096, + "loss_term": 0.020479349419474602, + "loss_bridge": 0.00925756897777319 + }, + { + "step": 5000, + "avg_cos": -0.007130145425132166, + "avg_rho": 0.003303757131410142, + "avg_nudge": 0.004990910800794761, + "loss_term": 0.03882830590009689, + "loss_bridge": 0.022515198215842247 + } + ] + }, + "lam_1.0": { + "best_cos": 0.2899630069732666, + "best_step": 500, + "final_cos": 0.0024576462844076255, + "final_rho": 0.010072723651925722, + "final_nudge": 0.0017782459035515785, + "history": [ + { + "step": 500, + "avg_cos": 0.2899630069732666, + "avg_rho": 0.3076498980323474, + "avg_nudge": -0.10753695170084636, + "loss_term": 0.49082812666893005, + "loss_bridge": 0.21152979135513306 + }, + { + "step": 1000, + "avg_cos": 0.17747302974263826, + "avg_rho": 0.17881879458824793, + "avg_nudge": -0.06420790683478117, + "loss_term": 0.1407778263092041, + "loss_bridge": 0.15578678250312805 + }, + { + "step": 1500, + "avg_cos": 0.08065215001503627, + "avg_rho": 0.08029377926141024, + "avg_nudge": -0.029326035796354216, + "loss_term": 0.07120706140995026, + "loss_bridge": 
0.09130121767520905 + }, + { + "step": 2000, + "avg_cos": 0.0563324602941672, + "avg_rho": 0.058026942308060825, + "avg_nudge": -0.01875296530003349, + "loss_term": 0.06783974170684814, + "loss_bridge": 0.01712188683450222 + }, + { + "step": 2500, + "avg_cos": 0.039659525423000254, + "avg_rho": 0.04613731553157171, + "avg_nudge": -0.013568413676694036, + "loss_term": 0.0634445995092392, + "loss_bridge": 0.02091756835579872 + }, + { + "step": 3000, + "avg_cos": 0.04386034506993989, + "avg_rho": 0.039356538482631244, + "avg_nudge": -0.01407914562150836, + "loss_term": 0.05031818151473999, + "loss_bridge": 0.01259439717978239 + }, + { + "step": 3500, + "avg_cos": 0.0481002099889641, + "avg_rho": 0.03689269566287597, + "avg_nudge": -0.01590493693947792, + "loss_term": 0.03421805799007416, + "loss_bridge": 0.014973677694797516 + }, + { + "step": 4000, + "avg_cos": 0.07999587121109168, + "avg_rho": 0.07683291751891375, + "avg_nudge": -0.02731190746029218, + "loss_term": 0.04726963862776756, + "loss_bridge": 0.017824556678533554 + }, + { + "step": 4500, + "avg_cos": 0.07246351769814889, + "avg_rho": 0.07343382605661948, + "avg_nudge": -0.027415843835721414, + "loss_term": 0.02768528461456299, + "loss_bridge": 0.013396943919360638 + }, + { + "step": 5000, + "avg_cos": 0.0024576462844076255, + "avg_rho": 0.010072723651925722, + "avg_nudge": 0.0017782459035515785, + "loss_term": 0.02114972099661827, + "loss_bridge": 0.011466547846794128 + } + ] + }, + "noise_lam": { + "best_cos": 0.28980905935168266, + "best_step": 500, + "final_cos": 0.000333610107190907, + "final_rho": 0.009229002442831794, + "final_nudge": 0.0020260093733668327, + "history": [ + { + "step": 500, + "avg_cos": 0.28980905935168266, + "avg_rho": 0.30755361666282016, + "avg_nudge": -0.10748158146937688, + "loss_term": 0.4891643524169922, + "loss_bridge": 0.21093463897705078 + }, + { + "step": 1000, + "avg_cos": 0.1692807301878929, + "avg_rho": 0.17357065031925836, + "avg_nudge": -0.060928904761870704, + 
"loss_term": 0.12851648032665253, + "loss_bridge": 0.1991802453994751 + }, + { + "step": 1500, + "avg_cos": 0.08976124723752339, + "avg_rho": 0.08664474201699097, + "avg_nudge": -0.032548267083863415, + "loss_term": 0.08289942145347595, + "loss_bridge": 0.08361449092626572 + }, + { + "step": 2000, + "avg_cos": 0.0449219069754084, + "avg_rho": 0.03701058775186539, + "avg_nudge": -0.013820620176071921, + "loss_term": 0.07352523505687714, + "loss_bridge": 0.031066572293639183 + }, + { + "step": 2500, + "avg_cos": 0.0341803894067804, + "avg_rho": 0.04086895197785149, + "avg_nudge": -0.011834259377792478, + "loss_term": 0.05221429467201233, + "loss_bridge": 0.013445280492305756 + }, + { + "step": 3000, + "avg_cos": 0.043432267770792045, + "avg_rho": 0.035848827101290226, + "avg_nudge": -0.0138394293996195, + "loss_term": 0.08377102017402649, + "loss_bridge": 0.02437596395611763 + }, + { + "step": 3500, + "avg_cos": 0.04308730812044814, + "avg_rho": 0.02904196917855491, + "avg_nudge": -0.013877601828426123, + "loss_term": 0.028318829834461212, + "loss_bridge": 0.0096125528216362 + }, + { + "step": 4000, + "avg_cos": 0.07217423028002183, + "avg_rho": 0.07531639956869185, + "avg_nudge": -0.024754961564516027, + "loss_term": 0.04526568949222565, + "loss_bridge": 0.01877186819911003 + }, + { + "step": 4500, + "avg_cos": 0.041415012208744884, + "avg_rho": 0.039630077119606234, + "avg_nudge": -0.013629865366965532, + "loss_term": 0.0305576603859663, + "loss_bridge": 0.011333338916301727 + }, + { + "step": 5000, + "avg_cos": 0.000333610107190907, + "avg_rho": 0.009229002442831794, + "avg_nudge": 0.0020260093733668327, + "loss_term": 0.020293015986680984, + "loss_bridge": 0.00792770553380251 + } + ] + }, + "no_ln": { + "best_cos": 0.2994285201032956, + "best_step": 500, + "final_cos": -0.027601251068214577, + "final_rho": -0.03011056105606258, + "final_nudge": 0.013365049380809069, + "history": [ + { + "step": 500, + "avg_cos": 0.2994285201032956, + "avg_rho": 
0.3287110353509585, + "avg_nudge": -0.11226618165771167, + "loss_term": 0.5378487706184387, + "loss_bridge": 0.20209679007530212 + }, + { + "step": 1000, + "avg_cos": 0.23427631705999374, + "avg_rho": 0.24483238657315573, + "avg_nudge": -0.08577251620590687, + "loss_term": 0.14587438106536865, + "loss_bridge": 0.17536549270153046 + }, + { + "step": 1500, + "avg_cos": 0.10971186744670074, + "avg_rho": 0.1081712432205677, + "avg_nudge": -0.039541066934665046, + "loss_term": 0.09813931584358215, + "loss_bridge": 0.1144137978553772 + }, + { + "step": 2000, + "avg_cos": 0.17568999342620373, + "avg_rho": 0.18043102137744427, + "avg_nudge": -0.06037815675760309, + "loss_term": 0.12119113653898239, + "loss_bridge": 0.05170433223247528 + }, + { + "step": 2500, + "avg_cos": 0.1501085925847292, + "avg_rho": 0.1694059126699964, + "avg_nudge": -0.05651436994473139, + "loss_term": 0.07073387503623962, + "loss_bridge": 0.015867799520492554 + }, + { + "step": 3000, + "avg_cos": 0.10612630782028039, + "avg_rho": 0.09910948442605634, + "avg_nudge": -0.034841354160259165, + "loss_term": 0.05772688612341881, + "loss_bridge": 0.0220709890127182 + }, + { + "step": 3500, + "avg_cos": 0.12348712608218193, + "avg_rho": 0.12745930068194866, + "avg_nudge": -0.04433920063699285, + "loss_term": 0.04150111600756645, + "loss_bridge": 0.014408521354198456 + }, + { + "step": 4000, + "avg_cos": 0.012287488003494218, + "avg_rho": 0.038615713633286454, + "avg_nudge": -0.006448045140132308, + "loss_term": 0.04848453402519226, + "loss_bridge": 0.026737889274954796 + }, + { + "step": 4500, + "avg_cos": 0.0344570055603981, + "avg_rho": 0.043244189505154886, + "avg_nudge": -0.012965732564528784, + "loss_term": 0.04390523582696915, + "loss_bridge": 0.016968518495559692 + }, + { + "step": 5000, + "avg_cos": -0.027601251068214577, + "avg_rho": -0.03011056105606258, + "avg_nudge": 0.013365049380809069, + "loss_term": 0.05419892817735672, + "loss_bridge": 0.029718749225139618 + } + ] + }, + "big_vnet": { + 
"best_cos": 0.25947993124524754, + "best_step": 500, + "final_cos": 0.012725223108039549, + "final_rho": -0.0029713623225688934, + "final_nudge": -0.0007455882926781973, + "history": [ + { + "step": 500, + "avg_cos": 0.25947993124524754, + "avg_rho": 0.2872927797337373, + "avg_nudge": -0.09759780826667945, + "loss_term": 0.24058431386947632, + "loss_bridge": 0.20110812783241272 + }, + { + "step": 1000, + "avg_cos": 0.11903308952848117, + "avg_rho": 0.10757205138603847, + "avg_nudge": -0.04069022353117665, + "loss_term": 0.1535366326570511, + "loss_bridge": 0.10731971263885498 + }, + { + "step": 1500, + "avg_cos": 0.04738504672423005, + "avg_rho": 0.04187516961246729, + "avg_nudge": -0.01667174060518543, + "loss_term": 0.12195796519517899, + "loss_bridge": 0.03276967257261276 + }, + { + "step": 2000, + "avg_cos": 0.05584627948701382, + "avg_rho": 0.06464844088380535, + "avg_nudge": -0.019484267104417086, + "loss_term": 0.07039390504360199, + "loss_bridge": 0.023653611540794373 + }, + { + "step": 2500, + "avg_cos": 0.11334512817362945, + "avg_rho": 0.13342145457863808, + "avg_nudge": -0.04256325137491027, + "loss_term": 0.09321287274360657, + "loss_bridge": 0.03603611886501312 + }, + { + "step": 3000, + "avg_cos": 0.07876436489944656, + "avg_rho": 0.08206061793801685, + "avg_nudge": -0.02565166230003039, + "loss_term": 0.03433217480778694, + "loss_bridge": 0.014776414260268211 + }, + { + "step": 3500, + "avg_cos": 0.059695989514390625, + "avg_rho": 0.043808821588754654, + "avg_nudge": -0.01833860871071617, + "loss_term": 0.07867280393838882, + "loss_bridge": 0.04651641845703125 + }, + { + "step": 4000, + "avg_cos": 0.008286381295571724, + "avg_rho": 0.01973852072842419, + "avg_nudge": -0.0019318210737158854, + "loss_term": 0.03502834588289261, + "loss_bridge": 0.011813998222351074 + }, + { + "step": 4500, + "avg_cos": -0.006770744492920737, + "avg_rho": 2.3505534045398235e-05, + "avg_nudge": 0.002897862965861956, + "loss_term": 0.04147114232182503, + "loss_bridge": 
0.026934277266263962 + }, + { + "step": 5000, + "avg_cos": 0.012725223108039549, + "avg_rho": -0.0029713623225688934, + "avg_nudge": -0.0007455882926781973, + "loss_term": 0.038749027997255325, + "loss_bridge": 0.01941092312335968 + } + ] + }, + "ema_0.999": { + "best_cos": 0.10180055970946948, + "best_step": 1000, + "final_cos": -0.01584776126158734, + "final_rho": -0.01703926082700491, + "final_nudge": 0.007290713644276063, + "history": [ + { + "step": 500, + "avg_cos": -0.005628384804974, + "avg_rho": 0.010925033983464042, + "avg_nudge": 0.0015428058492640655, + "loss_term": 0.5920301675796509, + "loss_bridge": 1.4890536069869995 + }, + { + "step": 1000, + "avg_cos": 0.10180055970946948, + "avg_rho": 0.10515290250380833, + "avg_nudge": -0.037424925404290356, + "loss_term": 0.5715007185935974, + "loss_bridge": 0.4665977954864502 + }, + { + "step": 1500, + "avg_cos": 0.021814276037427287, + "avg_rho": 0.003451728650058309, + "avg_nudge": -0.005690907438596089, + "loss_term": 0.2616257071495056, + "loss_bridge": 0.2758212387561798 + }, + { + "step": 2000, + "avg_cos": 0.03669632730695108, + "avg_rho": 0.036930523036668696, + "avg_nudge": -0.01006293793519338, + "loss_term": 0.11164076626300812, + "loss_bridge": 0.1407940685749054 + }, + { + "step": 2500, + "avg_cos": -0.020325756592986483, + "avg_rho": -0.027848235641916592, + "avg_nudge": 0.011153894321372112, + "loss_term": 0.15471391379833221, + "loss_bridge": 0.06181420385837555 + }, + { + "step": 3000, + "avg_cos": -0.0060501456415901584, + "avg_rho": -0.013405103546877703, + "avg_nudge": 0.005228333951284488, + "loss_term": 0.07506504654884338, + "loss_bridge": 0.08207326382398605 + }, + { + "step": 3500, + "avg_cos": -0.02149865326161186, + "avg_rho": -0.023167532014970977, + "avg_nudge": 0.010358475303898254, + "loss_term": 0.048137497156858444, + "loss_bridge": 0.03276998922228813 + }, + { + "step": 4000, + "avg_cos": -0.007104064881180723, + "avg_rho": -0.00720199760204802, + "avg_nudge": 
0.005406570465614398, + "loss_term": 0.03773114085197449, + "loss_bridge": 0.037960827350616455 + }, + { + "step": 4500, + "avg_cos": -0.0034141440119128674, + "avg_rho": -0.003708663280121982, + "avg_nudge": 0.002461720102777084, + "loss_term": 0.0416095145046711, + "loss_bridge": 0.03200625628232956 + }, + { + "step": 5000, + "avg_cos": -0.01584776126158734, + "avg_rho": -0.01703926082700491, + "avg_nudge": 0.007290713644276063, + "loss_term": 0.023291591554880142, + "loss_bridge": 0.026996765285730362 + } + ] + }, + "K16": { + "best_cos": 0.3187306026617686, + "best_step": 500, + "final_cos": 0.012039402422184745, + "final_rho": -0.002023946028202772, + "final_nudge": -0.0010173787983755271, + "history": [ + { + "step": 500, + "avg_cos": 0.3187306026617686, + "avg_rho": 0.3298306291302045, + "avg_nudge": -0.12337777391076088, + "loss_term": 0.3945310413837433, + "loss_bridge": 0.27078020572662354 + }, + { + "step": 1000, + "avg_cos": 0.15608959334592024, + "avg_rho": 0.14268888781468073, + "avg_nudge": -0.05713697926451763, + "loss_term": 0.13724187016487122, + "loss_bridge": 0.13912135362625122 + }, + { + "step": 1500, + "avg_cos": 0.07919560000300407, + "avg_rho": 0.08190769484887521, + "avg_nudge": -0.029933936273058254, + "loss_term": 0.08082282543182373, + "loss_bridge": 0.07666538655757904 + }, + { + "step": 2000, + "avg_cos": 0.06641554242620866, + "avg_rho": 0.055935436549286045, + "avg_nudge": -0.020354578581949074, + "loss_term": 0.07134468108415604, + "loss_bridge": 0.011993557214736938 + }, + { + "step": 2500, + "avg_cos": 0.0844428426741312, + "avg_rho": 0.0917752521733443, + "avg_nudge": -0.028687965202455718, + "loss_term": 0.0323462039232254, + "loss_bridge": 0.00667245127260685 + }, + { + "step": 3000, + "avg_cos": 0.0888429010907809, + "avg_rho": 0.06338833862294753, + "avg_nudge": -0.026075587142258883, + "loss_term": 0.04170331731438637, + "loss_bridge": 0.010068882256746292 + }, + { + "step": 3500, + "avg_cos": 0.016165781184099615, + 
"avg_rho": 0.014142975055923065, + "avg_nudge": -0.003436643397435546, + "loss_term": 0.03528498113155365, + "loss_bridge": 0.012309195473790169 + }, + { + "step": 4000, + "avg_cos": 0.06803990031282107, + "avg_rho": 0.055018783935035266, + "avg_nudge": -0.022906929564972717, + "loss_term": 0.027342472225427628, + "loss_bridge": 0.007336604408919811 + }, + { + "step": 4500, + "avg_cos": 0.024216643689821165, + "avg_rho": 0.04094697698019445, + "avg_nudge": -0.008364889615525803, + "loss_term": 0.027580715715885162, + "loss_bridge": 0.016118880361318588 + }, + { + "step": 5000, + "avg_cos": 0.012039402422184745, + "avg_rho": -0.002023946028202772, + "avg_nudge": -0.0010173787983755271, + "loss_term": 0.027855150401592255, + "loss_bridge": 0.010500052943825722 + } + ] + }, + "best_combo": { + "best_cos": 0.30479515840609867, + "best_step": 500, + "final_cos": -0.025737087552746136, + "final_rho": -0.01576789258979261, + "final_nudge": 0.011260819776604572, + "history": [ + { + "step": 500, + "avg_cos": 0.30479515840609867, + "avg_rho": 0.33246727536122006, + "avg_nudge": -0.11396304952601592, + "loss_term": 0.5129790306091309, + "loss_bridge": 0.21324321627616882 + }, + { + "step": 1000, + "avg_cos": 0.24110793943206468, + "avg_rho": 0.24976263443628946, + "avg_nudge": -0.08793549550076325, + "loss_term": 0.14881110191345215, + "loss_bridge": 0.1860560178756714 + }, + { + "step": 1500, + "avg_cos": 0.12106851922969024, + "avg_rho": 0.11747027436892192, + "avg_nudge": -0.042971268917123474, + "loss_term": 0.10358402132987976, + "loss_bridge": 0.11228226125240326 + }, + { + "step": 2000, + "avg_cos": 0.19137668733795485, + "avg_rho": 0.2001398652791977, + "avg_nudge": -0.06597264126564066, + "loss_term": 0.0836295336484909, + "loss_bridge": 0.03368356451392174 + }, + { + "step": 2500, + "avg_cos": 0.1382010855789607, + "avg_rho": 0.15255235826286176, + "avg_nudge": -0.05140705577408274, + "loss_term": 0.058304790407419205, + "loss_bridge": 0.014804087579250336 + }, + { 
+ "step": 3000, + "avg_cos": 0.0815443117171526, + "avg_rho": 0.07958398557578523, + "avg_nudge": -0.026181443439175684, + "loss_term": 0.059965550899505615, + "loss_bridge": 0.016236742958426476 + }, + { + "step": 3500, + "avg_cos": 0.09519872162491083, + "avg_rho": 0.09258671143713097, + "avg_nudge": -0.03418003007148703, + "loss_term": 0.03956954926252365, + "loss_bridge": 0.014235305599868298 + }, + { + "step": 4000, + "avg_cos": -0.011209671385586262, + "avg_rho": -0.007080828654579818, + "avg_nudge": 0.004587161975602309, + "loss_term": 0.03683673217892647, + "loss_bridge": 0.018107034265995026 + }, + { + "step": 4500, + "avg_cos": 0.04608155476550261, + "avg_rho": 0.05846519426753124, + "avg_nudge": -0.017059470837314922, + "loss_term": 0.043314699083566666, + "loss_bridge": 0.023334285244345665 + }, + { + "step": 5000, + "avg_cos": -0.025737087552746136, + "avg_rho": -0.01576789258979261, + "avg_nudge": 0.011260819776604572, + "loss_term": 0.03279898688197136, + "loss_bridge": 0.015516506507992744 + } + ] + }, + "noise_1.0": { + "best_cos": 0.2831856335202853, + "best_step": 500, + "final_cos": 0.010971122809375325, + "final_rho": 0.014127362189659229, + "final_nudge": -0.002180874968568484, + "history": [ + { + "step": 500, + "avg_cos": 0.2831856335202853, + "avg_rho": 0.30480503539244336, + "avg_nudge": -0.1052918794254462, + "loss_term": 0.4858350455760956, + "loss_bridge": 0.20565161108970642 + }, + { + "step": 1000, + "avg_cos": 0.14922113033632436, + "avg_rho": 0.14803783098856607, + "avg_nudge": -0.052950371988117695, + "loss_term": 0.1246427372097969, + "loss_bridge": 0.1951574981212616 + }, + { + "step": 1500, + "avg_cos": 0.032916761857147016, + "avg_rho": 0.030339293957998354, + "avg_nudge": -0.011420852970331907, + "loss_term": 0.059441905468702316, + "loss_bridge": 0.09553220868110657 + }, + { + "step": 2000, + "avg_cos": 0.02735897192421059, + "avg_rho": 0.019464978327353794, + "avg_nudge": -0.008191200438886881, + "loss_term": 
0.07519456744194031, + "loss_bridge": 0.017568862065672874 + }, + { + "step": 2500, + "avg_cos": 0.02141591941472143, + "avg_rho": 0.027642763530214626, + "avg_nudge": -0.007318320373694102, + "loss_term": 0.05027623474597931, + "loss_bridge": 0.014622311107814312 + }, + { + "step": 3000, + "avg_cos": 0.048597725903770574, + "avg_rho": 0.03585781451935569, + "avg_nudge": -0.015588213689625263, + "loss_term": 0.06457968056201935, + "loss_bridge": 0.01680811122059822 + }, + { + "step": 3500, + "avg_cos": 0.06355043267831206, + "avg_rho": 0.04236283013597131, + "avg_nudge": -0.021118728754421074, + "loss_term": 0.035970453172922134, + "loss_bridge": 0.015534179285168648 + }, + { + "step": 4000, + "avg_cos": 0.04567302499587337, + "avg_rho": 0.05423017560193936, + "avg_nudge": -0.01528523334612449, + "loss_term": 0.034011729061603546, + "loss_bridge": 0.009947280399501324 + }, + { + "step": 4500, + "avg_cos": 0.024254882785802085, + "avg_rho": 0.023739464891453583, + "avg_nudge": -0.009648310175786415, + "loss_term": 0.04320281371474266, + "loss_bridge": 0.01757601648569107 + }, + { + "step": 5000, + "avg_cos": 0.010971122809375325, + "avg_rho": 0.014127362189659229, + "avg_nudge": -0.002180874968568484, + "loss_term": 0.024934478104114532, + "loss_bridge": 0.011717695742845535 + } + ] + }, + "lr_3e-4": { + "best_cos": 0.6101815849542618, + "best_step": 500, + "final_cos": -0.008048781737064322, + "final_rho": -0.026018392760306597, + "final_nudge": 0.007571722225596507, + "history": [ + { + "step": 500, + "avg_cos": 0.6101815849542618, + "avg_rho": 0.620405301451683, + "avg_nudge": -0.2231019102036953, + "loss_term": 3.107595205307007, + "loss_bridge": 2.7781143188476562 + }, + { + "step": 1000, + "avg_cos": 0.39656415830055874, + "avg_rho": 0.39799897621075314, + "avg_nudge": -0.14455867062012354, + "loss_term": 0.29117679595947266, + "loss_bridge": 0.13909414410591125 + }, + { + "step": 1500, + "avg_cos": 0.26989879210789997, + "avg_rho": 0.2640038679043452, + 
"avg_nudge": -0.09908402090271314, + "loss_term": 0.1437155306339264, + "loss_bridge": 0.06845193356275558 + }, + { + "step": 2000, + "avg_cos": 0.15282577524582544, + "avg_rho": 0.1327554533878962, + "avg_nudge": -0.05070468131452799, + "loss_term": 0.10841675102710724, + "loss_bridge": 0.044324424117803574 + }, + { + "step": 2500, + "avg_cos": 0.054395756063361965, + "avg_rho": 0.04450509278103709, + "avg_nudge": -0.01604464929550886, + "loss_term": 0.09302366524934769, + "loss_bridge": 0.016296718269586563 + }, + { + "step": 3000, + "avg_cos": 0.041961303912103176, + "avg_rho": 0.032989607813457646, + "avg_nudge": -0.01202150775740544, + "loss_term": 0.06656567752361298, + "loss_bridge": 0.0089980224147439 + }, + { + "step": 3500, + "avg_cos": 0.010103868204168975, + "avg_rho": -0.01982816867530346, + "avg_nudge": 0.0005857348442077637, + "loss_term": 0.0479322224855423, + "loss_bridge": 0.004413220100104809 + }, + { + "step": 4000, + "avg_cos": 0.028387469239532948, + "avg_rho": 0.012580555553237597, + "avg_nudge": -0.006568876715997855, + "loss_term": 0.06536682695150375, + "loss_bridge": 0.005300351418554783 + }, + { + "step": 4500, + "avg_cos": 0.015142000513151288, + "avg_rho": 0.019022303090120356, + "avg_nudge": -0.004821705476691325, + "loss_term": 0.037237197160720825, + "loss_bridge": 0.005878066644072533 + }, + { + "step": 5000, + "avg_cos": -0.008048781737064322, + "avg_rho": -0.026018392760306597, + "avg_nudge": 0.007571722225596507, + "loss_term": 0.02460472844541073, + "loss_bridge": 0.005403056740760803 + } + ] + } +}
\ No newline at end of file diff --git a/results/toy_lq/toy_lq_seed42.json b/results/toy_lq/toy_lq_seed42.json new file mode 100644 index 0000000..0a821be --- /dev/null +++ b/results/toy_lq/toy_lq_seed42.json @@ -0,0 +1,335 @@ +{ + "config": { + "d_hidden": 64, + "output_dim": 10, + "num_layers": 12, + "sigma": 0.03, + "batch_size": 256, + "num_steps": 5000, + "lr_fb": 0.001, + "lam": 0.1, + "K": 8, + "ema_momentum": 0.995, + "sigma_bridge": 0.03, + "eval_every": 500, + "seed": 42, + "gpu": 0, + "output_dir": "results/toy_lq" + }, + "log": { + "steps": [ + 1, + 500, + 1000, + 1500, + 2000, + 2500, + 3000, + 3500, + 4000, + 4500, + 5000 + ], + "state_bridge_loss": [ + 66.01814270019531, + 2.058396816253662, + 2.191567897796631, + 2.1112077236175537, + 2.2748100757598877, + 2.1154427528381348, + 2.0553503036499023, + 1.966332197189331, + 2.1147494316101074, + 2.007577896118164, + 2.0389387607574463 + ], + "credit_bridge_loss": [ + 112.22884368896484, + 0.5559965372085571, + 0.36617255210876465, + 0.22275042533874512, + 0.13448813557624817, + 0.0859837457537651, + 0.06448937207460403, + 0.03930438682436943, + 0.033317387104034424, + 0.03704077750444412, + 0.02309737727046013 + ], + "dfa_costate_cos": [ + 0.0011988391992277824, + 0.007705975646296373, + -0.0006242827870524847, + 0.0037568132393062115, + 0.004209253507164808, + -0.0012397096635630499, + 0.0033803660816584644, + 0.003754911944270134, + 0.002183924035097544, + 0.0057398807973815845, + 0.0028941635615550554 + ], + "state_costate_cos": [ + 0.009303608121207127, + 0.9451924241506137, + 0.9427064611361577, + 0.9456643370481638, + 0.947175892499777, + 0.9486432488148029, + 0.9405259856810937, + 0.9429925313362708, + 0.9463646090947665, + 0.9434385620630704, + 0.9468063895518963 + ], + "credit_costate_cos": [ + 0.03412332414434506, + 0.3051199523302225, + 0.2641584941974053, + 0.17990898627501267, + 0.11904546274588658, + 0.03094297769264533, + 0.022915477577883463, + 0.008693795901938127, + 
0.0027211602920523058, + 0.020937439484091904, + 0.033342053540624104 + ], + "dfa_rho": [ + 0.005011526596111556, + 0.0020135376447190842, + -0.011304221504057447, + 0.003935044708972176, + 0.0159942601264144, + -0.011545649264007807, + 0.01096861291443929, + 0.0007782066240906715, + -0.015019190264865756, + 0.007689292387415965, + 0.006176682732378443 + ], + "state_rho": [ + 0.011923154350370169, + 0.9337877084811529, + 0.9292550335327784, + 0.9341330577929815, + 0.9364036619663239, + 0.9322425921758016, + 0.9240961174170176, + 0.9329939633607864, + 0.9331585764884949, + 0.9324665367603302, + 0.9281178514162699 + ], + "credit_rho": [ + 0.031667908265565835, + 0.31275976697603863, + 0.2433429310719172, + 0.1816188059747219, + 0.11667139704028766, + 0.013198353117331862, + 0.022044080891646445, + -0.007547003333456814, + -0.011566813724736372, + -0.003948230994865298, + 0.016950203105807304 + ], + "dfa_nudge": [ + -0.0003799900102118651, + -0.0025626374408602715, + 0.0017628272374471028, + -0.001205168974896272, + -0.0011821148606638114, + 0.0014717701512078445, + -8.787959814071655e-05, + -0.0006076549955954155, + 0.0005303900688886642, + -0.0014991160326947768, + -0.0005284918782611688 + ], + "state_nudge": [ + -0.002327537008871635, + -0.34574924657742184, + -0.3358767156799634, + -0.33698513607184094, + -0.3561149264375369, + -0.33514803399642307, + -0.3557068184018135, + -0.3218521823485692, + -0.33325668424367905, + -0.34358637283245724, + -0.34090926001469296 + ], + "credit_nudge": [ + -0.014598140881086389, + -0.11370646270612876, + -0.09322128010292847, + -0.06459770910441875, + -0.04436610918492079, + -0.008244700108965239, + -0.00636714743450284, + 0.0006991980286935965, + 0.0036097665627797446, + -0.0035287897723416486, + -0.007594603579491377 + ], + "bridge_residual": [ + 0.06566914729773998, + 0.37236853316426277, + 0.3114783614873886, + 0.26258066420753795, + 0.20779911428689957, + 0.13781529137243828, + 0.10999641008675098, + 0.07969006771842639, + 
0.06178054213523865, + 0.07378745886186759, + 0.0832565538585186 + ] + }, + "final_per_layer": { + "dfa_costate_cos": [ + -0.039167389273643494, + -0.0378018394112587, + 0.005690325051546097, + -0.023073989897966385, + -0.0005057593807578087, + -0.014485953375697136, + 0.03301015496253967, + 0.04401148855686188, + 0.054177843034267426, + 0.03981431573629379, + -0.04246171563863754, + -0.0012151142582297325, + 0.01963176019489765 + ], + "state_costate_cos": [ + 0.9444395303726196, + 0.9453534483909607, + 0.9460644721984863, + 0.9466040134429932, + 0.9469730257987976, + 0.9472954273223877, + 0.9476633667945862, + 0.9478192925453186, + 0.947782576084137, + 0.9473406672477722, + 0.947346568107605, + 0.9471475481987, + 0.9466531276702881 + ], + "credit_costate_cos": [ + 0.04752141237258911, + 0.04377397149801254, + 0.04051002860069275, + 0.03716123104095459, + 0.0350164994597435, + 0.03194836527109146, + 0.02999947965145111, + 0.02913709171116352, + 0.027683185413479805, + 0.0277146864682436, + 0.027703404426574707, + 0.027434173971414566, + 0.027843166142702103 + ], + "dfa_rho": [ + -0.04803554713726044, + 0.001050771214067936, + 0.008967258036136627, + -0.0271889790892601, + 0.02336559258401394, + -0.018210411071777344, + 0.05891512706875801, + 0.040720634162425995, + 0.07478035986423492, + 0.04802168905735016, + -0.06280035525560379, + -0.0254659466445446 + ], + "state_rho": [ + 0.9311305284500122, + 0.9222633838653564, + 0.9287852644920349, + 0.9287664890289307, + 0.9245603084564209, + 0.928197979927063, + 0.9275168180465698, + 0.9290561676025391, + 0.9267844557762146, + 0.927483081817627, + 0.9315245151519775, + 0.9313452243804932 + ], + "credit_rho": [ + 0.05740518122911453, + 0.035541512072086334, + 0.002091987058520317, + 0.024556485936045647, + -0.006993812508881092, + 0.03284040838479996, + 0.012268777936697006, + -0.004999782890081406, + 0.014774687588214874, + -0.010628825053572655, + 0.05940534919500351, + -0.012859531678259373 + ], + "dfa_nudge": [ + 
0.01363457553088665, + 0.013758538290858269, + -0.0032786596566438675, + 0.010209780186414719, + -0.0013850200921297073, + 0.004463233053684235, + -0.012735363095998764, + -0.01801125332713127, + -0.019914839416742325, + -0.012350432574748993, + 0.017156170681118965, + 0.002111367881298065 + ], + "state_nudge": [ + -0.34207087755203247, + -0.3417494297027588, + -0.3413676619529724, + -0.3414176106452942, + -0.3412688076496124, + -0.34025871753692627, + -0.340381920337677, + -0.34054237604141235, + -0.3402785062789917, + -0.34061968326568604, + -0.34013575315475464, + -0.340819776058197 + ], + "credit_nudge": [ + -0.012597911059856415, + -0.011231275275349617, + -0.009987404569983482, + -0.008699672296643257, + -0.007977090775966644, + -0.006919408217072487, + -0.006156830117106438, + -0.005891043692827225, + -0.00546320341527462, + -0.005353953689336777, + -0.005423387512564659, + -0.005434062331914902 + ], + "bridge_residual": [ + 0.10028190910816193, + 0.09995798766613007, + 0.09832010418176651, + 0.09528136253356934, + 0.09090456366539001, + 0.0843343734741211, + 0.07823348790407181, + 0.0723348930478096, + 0.06821676343679428, + 0.06773027032613754, + 0.06943590939044952, + 0.0740470215678215 + ] + } +}
\ No newline at end of file diff --git a/results/toy_lq/toy_lq_v2_seed123_lam0.1_sig0.1_tgw1.0_fm0.0.json b/results/toy_lq/toy_lq_v2_seed123_lam0.1_sig0.1_tgw1.0_fm0.0.json new file mode 100644 index 0000000..56b3336 --- /dev/null +++ b/results/toy_lq/toy_lq_v2_seed123_lam0.1_sig0.1_tgw1.0_fm0.0.json @@ -0,0 +1,330 @@ +{ + "config": { + "d_hidden": 64, + "output_dim": 10, + "num_layers": 12, + "sigma": 0.03, + "batch_size": 256, + "num_steps": 8000, + "lr_fb": 0.001, + "lam": 0.1, + "K": 8, + "ema_momentum": 0.995, + "sigma_bridge": 0.1, + "eval_every": 1000, + "seed": 123, + "gpu": 0, + "output_dir": "results/toy_lq", + "vnet_hidden": 256, + "vnet_layers": 3, + "term_grad_weight": 1.0, + "fm_weight": 0.0 + }, + "log": { + "steps": [ + 1, + 1000, + 2000, + 3000, + 4000, + 5000, + 6000, + 7000, + 8000 + ], + "dfa_costate_cos": [ + 0.0061469420325011015, + 0.007492478704079986, + 0.006436596314112346, + 0.001648913836106658, + 0.005657717352733016, + 0.010142655577510595, + 0.005493079700196783, + 0.008208209726338586, + 0.002802101274331411 + ], + "state_costate_cos": [ + 0.04875442975511154, + 0.9345830877621969, + 0.9331425180037817, + 0.9344809154669443, + 0.9360056469837824, + 0.9349906196196874, + 0.9401087015867233, + 0.9377894500891367, + 0.9381228536367416 + ], + "credit_costate_cos": [ + 0.005350367398932576, + 0.8715064575274786, + 0.9082922885815302, + 0.9268463253974915, + 0.9348577807346979, + 0.9338823060194651, + 0.9398750513792038, + 0.9371217240889868, + 0.9403592944145203 + ], + "dfa_rho": [ + 0.014851124413932363, + 0.004832483362406492, + 0.005500619454930226, + 0.0014784028753638268, + 0.0024716570042073727, + 0.002679668522129456, + 0.004042171291075647, + 0.007841781480237842, + -0.003731151965136329 + ], + "state_rho": [ + 0.0525277191773057, + 0.9209283490975698, + 0.9212760378917059, + 0.9215241422255834, + 0.9252830098072687, + 0.9173514097929001, + 0.9268933484951655, + 0.9245105236768723, + 0.9267788628737131 + ], + "credit_rho": [ + 
8.900166722014546e-05, + 0.8210234741369883, + 0.8804274102052053, + 0.9100336680809656, + 0.9228341629107794, + 0.9162226617336273, + 0.9262114216883978, + 0.9239083131154379, + 0.9273379941781362 + ], + "dfa_nudge": [ + -0.0020856610499322414, + -0.002391135785728693, + -0.0018826122395694256, + 0.0002794961134592692, + -0.0017084906188150246, + -0.0027558018919080496, + -0.0014085437481602032, + -0.0023310642379025617, + -0.0009137461893260479 + ], + "state_nudge": [ + -0.01762500188002984, + -0.31959830472866696, + -0.3143775438268979, + -0.3134472444653511, + -0.33237058420976, + -0.3183835695187251, + -0.32320864746967953, + -0.32294898976882297, + -0.31496328860521317 + ], + "credit_nudge": [ + -0.0004618208234508832, + -0.2996478999654452, + -0.30627985050280887, + -0.3099561383326848, + -0.33090290675560635, + -0.31673717498779297, + -0.3218600004911423, + -0.32163529098033905, + -0.31442063798507053 + ], + "bridge_residual": [], + "state_bridge_loss": [ + 64.66886901855469, + 1.9055249691009521, + 1.9802964925765991, + 1.9747117757797241, + 1.9333608150482178, + 2.0244460105895996, + 2.235288381576538, + 1.8483705520629883, + 1.833404541015625 + ], + "credit_bridge_loss": [ + 129.2601776123047, + 8.826760292053223, + 8.817267417907715, + 9.7413911819458, + 8.333325386047363, + 8.390682220458984, + 8.381990432739258, + 9.069635391235352, + 8.738546371459961 + ], + "term_loss": [ + 109.68403625488281, + 3.380112648010254, + 4.175432205200195, + 4.737654685974121, + 3.632157802581787, + 3.3351938724517822, + 3.776655912399292, + 4.416824817657471, + 4.120006084442139 + ], + "bridge_loss": [ + 5.943464884694549e-07, + 0.2433367669582367, + 0.18713834881782532, + 0.12417592853307724, + 0.13950209319591522, + 0.14192476868629456, + 0.15276893973350525, + 0.10663559287786484, + 0.12137635797262192 + ], + "term_grad_loss": [ + 19.57614517211914, + 5.203310966491699, + 4.4546966552734375, + 4.879560470581055, + 4.5616655349731445, + 4.9135637283325195, + 
4.452565670013428, + 4.546175003051758, + 4.497163772583008 + ], + "fm_loss": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ] + }, + "final_per_layer": { + "dfa_costate_cos": [ + 0.04479089006781578, + -0.04426354169845581, + 0.008874599821865559, + 0.05599997937679291, + 0.02961653470993042, + -0.022058574482798576, + 0.027186769992113113, + -0.0337681919336319, + -0.020245034247636795, + -0.04076787084341049, + 0.006424968130886555, + 0.021834686398506165 + ], + "state_costate_cos": [ + 0.9351104497909546, + 0.9362065196037292, + 0.9370632171630859, + 0.9374789595603943, + 0.9381210803985596, + 0.9385455846786499, + 0.9386894702911377, + 0.9390624165534973, + 0.9393105506896973, + 0.9393242597579956, + 0.9392982721328735, + 0.9392634630203247 + ], + "credit_costate_cos": [ + 0.9360430240631104, + 0.9369803667068481, + 0.9380120038986206, + 0.9385954141616821, + 0.9395070672035217, + 0.9403176307678223, + 0.9409258961677551, + 0.9413450956344604, + 0.9420279264450073, + 0.9428501725196838, + 0.9435971975326538, + 0.9441097378730774 + ], + "dfa_rho": [ + 0.041250549256801605, + -0.049739208072423935, + 0.00025176629424095154, + 0.04237007349729538, + 0.040798719972372055, + -0.037202395498752594, + 0.004211327061057091, + 0.0017704367637634277, + -0.04704931750893593, + -0.02992381900548935, + -0.01541070081293583, + 0.003898744471371174 + ], + "state_rho": [ + 0.9230573177337646, + 0.9292065501213074, + 0.9277745485305786, + 0.926990270614624, + 0.930223822593689, + 0.9262816905975342, + 0.9238564968109131, + 0.9235492944717407, + 0.927101731300354, + 0.929185152053833, + 0.9233707189559937, + 0.9307487607002258 + ], + "credit_rho": [ + 0.924521803855896, + 0.9213208556175232, + 0.923781156539917, + 0.9222273826599121, + 0.9230118989944458, + 0.9317750930786133, + 0.9238618016242981, + 0.9351733326911926, + 0.9263075590133667, + 0.9302359819412231, + 0.9358923435211182, + 0.9299467206001282 + ], + "dfa_nudge": [ + -0.01619502529501915, + 
0.016632290557026863, + -0.004560360684990883, + -0.018955951556563377, + -0.011058392003178596, + 0.007517071440815926, + -0.008148249238729477, + 0.01091383583843708, + 0.006811931729316711, + 0.015017258003354073, + -0.000978708267211914, + -0.00796065479516983 + ], + "state_nudge": [ + -0.31427979469299316, + -0.3147982358932495, + -0.3148774206638336, + -0.3150269389152527, + -0.31536394357681274, + -0.3156891465187073, + -0.3149397373199463, + -0.3145085573196411, + -0.3144562840461731, + -0.31494998931884766, + -0.3152415156364441, + -0.31542789936065674 + ], + "credit_nudge": [ + -0.31250905990600586, + -0.3131006360054016, + -0.31331712007522583, + -0.31366780400276184, + -0.3142547011375427, + -0.3149237632751465, + -0.31442567706108093, + -0.314262330532074, + -0.31445154547691345, + -0.3153681755065918, + -0.3161371052265167, + -0.3166297376155853 + ] + } +}
\ No newline at end of file diff --git a/results/toy_lq/toy_lq_v2_seed42_lam0.1_sig0.1_tgw1.0_fm0.0.json b/results/toy_lq/toy_lq_v2_seed42_lam0.1_sig0.1_tgw1.0_fm0.0.json new file mode 100644 index 0000000..14050eb --- /dev/null +++ b/results/toy_lq/toy_lq_v2_seed42_lam0.1_sig0.1_tgw1.0_fm0.0.json @@ -0,0 +1,458 @@ +{ + "config": { + "d_hidden": 64, + "output_dim": 10, + "num_layers": 12, + "sigma": 0.03, + "batch_size": 256, + "num_steps": 8000, + "lr_fb": 0.001, + "lam": 0.1, + "K": 8, + "ema_momentum": 0.995, + "sigma_bridge": 0.1, + "eval_every": 500, + "seed": 42, + "gpu": 0, + "output_dir": "results/toy_lq", + "vnet_hidden": 256, + "vnet_layers": 3, + "term_grad_weight": 1.0, + "fm_weight": 0.0 + }, + "log": { + "steps": [ + 1, + 500, + 1000, + 1500, + 2000, + 2500, + 3000, + 3500, + 4000, + 4500, + 5000, + 5500, + 6000, + 6500, + 7000, + 7500, + 8000 + ], + "dfa_costate_cos": [ + 0.001022340264171362, + 0.00324871806272616, + 0.003955680275491129, + -6.335701133745412e-05, + 0.0009047117297692845, + -0.0015635235855976741, + 0.0050255006256823736, + 0.003974012487257521, + -0.0024126653637116155, + 0.0010036304011009634, + -0.0006338046708454689, + 0.00017608713824301958, + -9.714129070440929e-05, + -0.0011926805212472875, + -0.0026168684319903455, + -0.003671417556082209, + 0.0016770829679444432 + ], + "state_costate_cos": [ + 0.009337988216429949, + 0.943408285578092, + 0.9517824550469717, + 0.9448912839094797, + 0.9432462950547537, + 0.9444176405668259, + 0.9435405482848486, + 0.9443136354287466, + 0.9458288550376892, + 0.9465046127637228, + 0.9465998113155365, + 0.9461832990248998, + 0.9424780905246735, + 0.9474884221951166, + 0.9449383119742075, + 0.9425351719061533, + 0.9405365288257599 + ], + "credit_costate_cos": [ + 0.024892715892444055, + 0.8234576731920242, + 0.8861371924479803, + 0.8940252065658569, + 0.9119344304005305, + 0.928387979666392, + 0.9261951943238577, + 0.9374307443698248, + 0.939670259753863, + 0.9423713783423106, + 
0.9447299987077713, + 0.944857731461525, + 0.9409955541292826, + 0.9459459533294042, + 0.9439971546332041, + 0.9420813073714575, + 0.9400773843129476 + ], + "dfa_rho": [ + 0.015879416760678094, + 0.004265802912414074, + 0.008714484671751658, + 0.001407407767449816, + 0.007156838779337704, + -0.0010706717148423195, + -0.00500367038572828, + 0.00037602245962868136, + 0.00797194141584138, + 0.012475600582547486, + 0.006475616673318048, + 0.006899521841357152, + 0.008204833992446462, + -0.0002780493038396041, + -0.009471656677002708, + -0.00781721225939691, + 0.025706250220537186 + ], + "state_rho": [ + 0.0029325426245729127, + 0.9265175064404806, + 0.938806007305781, + 0.9280826350053152, + 0.9287765026092529, + 0.9266663392384847, + 0.9301863809426626, + 0.9323704888423284, + 0.9297124246756235, + 0.9302101731300354, + 0.9340693553288778, + 0.9324037134647369, + 0.9265123655398687, + 0.9321505973736445, + 0.927787164847056, + 0.9323761165142059, + 0.9224280416965485 + ], + "credit_rho": [ + 0.02234963719577839, + 0.7719902147849401, + 0.8432890474796295, + 0.8447659462690353, + 0.877406562368075, + 0.8958166440327963, + 0.9122767845789591, + 0.9254603485266367, + 0.9236042300860087, + 0.9263056516647339, + 0.9323819329341253, + 0.9314924577871958, + 0.9266939163208008, + 0.929329847296079, + 0.9257104198137919, + 0.9289217789967855, + 0.9269137680530548 + ], + "dfa_nudge": [ + -0.0003799900102118651, + -0.0009698765352368355, + -0.0013957968913018703, + 0.00035563452790180844, + -0.0003614289841304223, + 0.0008372832089662552, + -0.0013908503266672294, + -0.0013163429684937, + 0.0005854369762043158, + -0.000560786963130037, + 0.00015371766251822314, + -1.0695696497956911e-05, + 2.807355485856533e-05, + 5.036763225992521e-05, + 0.0008217894161740938, + 0.0017250357971837123, + -0.0004334271264572938 + ], + "state_nudge": [ + -0.002327537008871635, + -0.34193428109089535, + -0.340603639682134, + -0.35042013972997665, + -0.3389856591820717, + -0.34554000198841095, + 
-0.3481511374314626, + -0.3558509051799774, + -0.3337947155038516, + -0.33715616663297016, + -0.33583804468313855, + -0.35106155276298523, + -0.34434546530246735, + -0.34641525397698086, + -0.3329972525437673, + -0.34870391835769016, + -0.3395152688026428 + ], + "credit_nudge": [ + -0.0079942528779308, + -0.30343081553777057, + -0.31936119496822357, + -0.33327333877484006, + -0.3284662067890167, + -0.3394550507267316, + -0.3424832498033841, + -0.3524467721581459, + -0.33071904132763547, + -0.3346964443723361, + -0.3337800477941831, + -0.34932391593853634, + -0.34230031818151474, + -0.3444634775320689, + -0.3310972551504771, + -0.34694332132736844, + -0.3377286195755005 + ], + "bridge_residual": [], + "state_bridge_loss": [ + 66.01814270019531, + 2.0012869834899902, + 2.1027088165283203, + 2.1019272804260254, + 2.0727572441101074, + 1.9770110845565796, + 2.1761908531188965, + 2.094480514526367, + 1.9725749492645264, + 2.0142102241516113, + 2.0340821743011475, + 1.9380583763122559, + 1.989743947982788, + 2.4057328701019287, + 2.19437575340271, + 2.1816155910491943, + 2.0803794860839844 + ], + "credit_bridge_loss": [ + 132.09298706054688, + 11.307573318481445, + 9.575517654418945, + 8.768391609191895, + 8.752981185913086, + 9.271968841552734, + 8.606213569641113, + 8.753968238830566, + 8.223488807678223, + 9.935165405273438, + 9.02967357635498, + 8.346063613891602, + 9.022041320800781, + 8.36446762084961, + 8.635725021362305, + 8.760185241699219, + 8.503408432006836 + ], + "term_loss": [ + 111.63633728027344, + 4.978545188903809, + 3.8962953090667725, + 3.81073260307312, + 4.386394500732422, + 4.507748603820801, + 3.6740365028381348, + 3.7450127601623535, + 3.5060954093933105, + 4.545898914337158, + 4.322302341461182, + 3.594371795654297, + 4.038668632507324, + 4.025404453277588, + 3.9016177654266357, + 3.7638583183288574, + 3.807260036468506 + ], + "bridge_loss": [ + 6.45359421014291e-07, + 0.432157039642334, + 0.20619139075279236, + 0.24715952575206757, + 
0.1523856669664383, + 0.12624874711036682, + 0.13425695896148682, + 0.16367560625076294, + 0.14530189335346222, + 0.18452416360378265, + 0.11599670350551605, + 0.11983858048915863, + 0.11901542544364929, + 0.16489851474761963, + 0.16058529913425446, + 0.09414370357990265, + 0.12402483820915222 + ], + "term_grad_loss": [ + 20.456655502319336, + 5.8968706130981445, + 5.4730305671691895, + 4.710499286651611, + 4.214200496673584, + 4.637970924377441, + 4.797920227050781, + 4.845280170440674, + 4.572091579437256, + 5.204742908477783, + 4.591374397277832, + 4.631853103637695, + 4.864356994628906, + 4.174164772033691, + 4.573522090911865, + 4.9021830558776855, + 4.5721235275268555 + ], + "fm_loss": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ] + }, + "final_per_layer": { + "dfa_costate_cos": [ + -0.03974368795752525, + -0.03234883397817612, + -0.01280729565769434, + -0.03282422572374344, + 0.005050511099398136, + -0.01393066719174385, + 0.036346256732940674, + 0.04377925395965576, + 0.06324614584445953, + 0.05324942618608475, + -0.04297419637441635, + -0.006917691323906183 + ], + "state_costate_cos": [ + 0.9383341073989868, + 0.9390445947647095, + 0.9393846988677979, + 0.9396730661392212, + 0.9402315616607666, + 0.940619707107544, + 0.9413514137268066, + 0.9420933723449707, + 0.941848874092102, + 0.9416857957839966, + 0.9411274194717407, + 0.9410437345504761 + ], + "credit_costate_cos": [ + 0.937440037727356, + 0.9379364252090454, + 0.9384041428565979, + 0.9385387897491455, + 0.9392049312591553, + 0.9398118257522583, + 0.9403334259986877, + 0.9412999153137207, + 0.94156414270401, + 0.9419326186180115, + 0.9419490098953247, + 0.9425133466720581 + ], + "dfa_rho": [ + -0.023802103474736214, + -0.039411380887031555, + -0.008970402181148529, + -0.0021191751584410667, + 0.04573667049407959, + 0.010564171709120274, + 0.04995737224817276, + 0.10094872862100601, + 0.07801118493080139, + 
0.07188688963651657, + -0.02006501331925392, + 0.045738060027360916 + ], + "state_rho": [ + 0.9204449653625488, + 0.917449951171875, + 0.9193954467773438, + 0.9267550706863403, + 0.9280773401260376, + 0.9225738048553467, + 0.9190236926078796, + 0.924782931804657, + 0.9248216152191162, + 0.9217315912246704, + 0.9215092658996582, + 0.9225708246231079 + ], + "credit_rho": [ + 0.9222818613052368, + 0.923028826713562, + 0.924055814743042, + 0.9263180494308472, + 0.9288915991783142, + 0.922785758972168, + 0.9231191873550415, + 0.9254114627838135, + 0.9321990013122559, + 0.9311563372612, + 0.934211015701294, + 0.9295063018798828 + ], + "dfa_nudge": [ + 0.014356113970279694, + 0.013086749240756035, + 0.004036703146994114, + 0.011657552793622017, + -0.0023863790556788445, + 0.006091207265853882, + -0.013204541988670826, + -0.016102034598588943, + -0.022989045828580856, + -0.018843289464712143, + 0.015521062538027763, + 0.0035747764632105827 + ], + "state_nudge": [ + -0.34063079953193665, + -0.34022200107574463, + -0.33977580070495605, + -0.3401387929916382, + -0.34019535779953003, + -0.33964985609054565, + -0.3392874002456665, + -0.33898109197616577, + -0.3392573297023773, + -0.3390581011772156, + -0.33809930086135864, + -0.33888739347457886 + ], + "credit_nudge": [ + -0.3375471234321594, + -0.33736705780029297, + -0.33710652589797974, + -0.33756691217422485, + -0.3378610908985138, + -0.33758991956710815, + -0.33741605281829834, + -0.3374325633049011, + -0.33806926012039185, + -0.33822208642959595, + -0.33765909075737, + -0.3389057517051697 + ] + } +}
\ No newline at end of file diff --git a/results/toy_lq/toy_lq_v2_seed42_lam0.1_sig0.1_tgw1.0_fm0.1.json b/results/toy_lq/toy_lq_v2_seed42_lam0.1_sig0.1_tgw1.0_fm0.1.json new file mode 100644 index 0000000..2977e98 --- /dev/null +++ b/results/toy_lq/toy_lq_v2_seed42_lam0.1_sig0.1_tgw1.0_fm0.1.json @@ -0,0 +1,330 @@ +{ + "config": { + "d_hidden": 64, + "output_dim": 10, + "num_layers": 12, + "sigma": 0.03, + "batch_size": 256, + "num_steps": 8000, + "lr_fb": 0.001, + "lam": 0.1, + "K": 8, + "ema_momentum": 0.995, + "sigma_bridge": 0.1, + "eval_every": 1000, + "seed": 42, + "gpu": 0, + "output_dir": "results/toy_lq", + "vnet_hidden": 256, + "vnet_layers": 3, + "term_grad_weight": 1.0, + "fm_weight": 0.1 + }, + "log": { + "steps": [ + 1, + 1000, + 2000, + 3000, + 4000, + 5000, + 6000, + 7000, + 8000 + ], + "dfa_costate_cos": [ + -0.003210080186060319, + 0.006517285481095314, + 0.002081584728633364, + 0.002482607301014165, + 0.0003743169557613631, + -0.0018058015266433358, + 0.005970992535973589, + 0.0013112322388527293, + -0.00010686229992037018 + ], + "state_costate_cos": [ + -0.008148168989767631, + 0.942050834496816, + 0.9447367936372757, + 0.9457606921593348, + 0.9468860725561777, + 0.9432125687599182, + 0.9417577485243479, + 0.945394163330396, + 0.9473433345556259 + ], + "credit_costate_cos": [ + -0.011870964197441936, + 0.8804089923699697, + 0.9169842700163523, + 0.9352772484223048, + 0.9427581528822581, + 0.9391989062229792, + 0.9398341725269953, + 0.9456242968638738, + 0.9460541109244028 + ], + "dfa_rho": [ + -0.012772266054525971, + 0.015942092907304566, + 0.008943260026474794, + 0.004492536109561722, + -0.007791806012392044, + -0.007830069400370121, + 0.014392409706488252, + 0.0063066319562494755, + -0.001311147507900993 + ], + "state_rho": [ + -0.0018634579222028453, + 0.9263874938090643, + 0.9358506997426351, + 0.9348766555388769, + 0.9348586251338323, + 0.9323162386814753, + 0.9287882298231125, + 0.9243461688359579, + 0.9346484492222468 + ], + 
"credit_rho": [ + -0.023325924955618877, + 0.8106876164674759, + 0.8797721316417059, + 0.9208102275927862, + 0.9290666033824285, + 0.9226995905240377, + 0.9258128056923548, + 0.9266915867726008, + 0.9364602218071619 + ], + "dfa_nudge": [ + 0.0015840742271393538, + -0.0023392424918711185, + -0.0002702907659113407, + -0.0010186488119264443, + -0.0001735797462364038, + 0.000983052421361208, + -0.0020705487113445997, + -0.0009097627674539884, + 0.0002721068449318409 + ], + "state_nudge": [ + 0.0024757004963854947, + -0.33942782630523044, + -0.35110870252052945, + -0.35064691056807834, + -0.3403419057528178, + -0.3532385254899661, + -0.33996084084113437, + -0.35185040285189945, + -0.3480343023935954 + ], + "credit_nudge": [ + 0.006069206205817561, + -0.31947339574495953, + -0.3404841795563698, + -0.3458903282880783, + -0.3376796667774518, + -0.3504715636372566, + -0.3379351521531741, + -0.3500200683871905, + -0.3462969238559405 + ], + "bridge_residual": [], + "state_bridge_loss": [ + 66.01814270019531, + 2.341611623764038, + 1.9537389278411865, + 1.9830116033554077, + 2.0491604804992676, + 2.0490386486053467, + 2.1182405948638916, + 2.373213291168213, + 1.9122812747955322 + ], + "credit_bridge_loss": [ + 132.09298706054688, + 9.123503684997559, + 8.516526222229004, + 8.634014129638672, + 8.720410346984863, + 8.43734359741211, + 10.247673034667969, + 8.351385116577148, + 8.474419593811035 + ], + "term_loss": [ + 111.63633728027344, + 4.155745029449463, + 3.8754897117614746, + 4.040826320648193, + 4.010752201080322, + 4.2646074295043945, + 5.391899108886719, + 3.27858829498291, + 3.74959135055542 + ], + "bridge_loss": [ + 6.45359421014291e-07, + 0.22446846961975098, + 0.18948684632778168, + 0.14564603567123413, + 0.1301368772983551, + 0.1614246517419815, + 0.19166265428066254, + 0.19554881751537323, + 0.15442883968353271 + ], + "term_grad_loss": [ + 20.456655502319336, + 4.743132591247559, + 4.45129919052124, + 4.447327136993408, + 4.579296112060547, + 4.011109352111816, 
+ 4.663890838623047, + 4.877087593078613, + 4.57021427154541 + ], + "fm_loss": [ + 1.4957863925246784e-07, + 0.001574160298332572, + 0.0025115222670137882, + 0.002142899436876178, + 0.0022483705542981625, + 0.0020201210863888264, + 0.00219835271127522, + 0.0016042720526456833, + 0.0018535356502979994 + ] + }, + "final_per_layer": { + "dfa_costate_cos": [ + -0.04993891716003418, + -0.03484980762004852, + 0.008264334872364998, + -0.018589604645967484, + -0.019490346312522888, + 0.0018934037070721388, + 0.025425557047128677, + 0.046417269855737686, + 0.06304077804088593, + 0.03301604464650154, + -0.06211906671524048, + 0.005648006685078144 + ], + "state_costate_cos": [ + 0.9454483985900879, + 0.946227490901947, + 0.9468960165977478, + 0.9470230340957642, + 0.9475629925727844, + 0.9478208422660828, + 0.9478966593742371, + 0.9480605125427246, + 0.9479074478149414, + 0.9478251934051514, + 0.9478766918182373, + 0.9475747346878052 + ], + "credit_costate_cos": [ + 0.9432812333106995, + 0.9438751935958862, + 0.9443729519844055, + 0.9447157382965088, + 0.9455237984657288, + 0.9461240768432617, + 0.9463104009628296, + 0.9467931985855103, + 0.9471821784973145, + 0.9476308226585388, + 0.9481172561645508, + 0.9487224817276001 + ], + "dfa_rho": [ + -0.0704549178481102, + 0.031227122992277145, + 0.009932642802596092, + -0.015737976878881454, + -0.045219600200653076, + 0.013525029644370079, + 0.03365146368741989, + 0.04508150368928909, + 0.05846859887242317, + 7.291417568922043e-05, + -0.06866942346096039, + -0.007611127570271492 + ], + "state_rho": [ + 0.935008704662323, + 0.9310771822929382, + 0.9295646548271179, + 0.938683032989502, + 0.9329910278320312, + 0.9313007593154907, + 0.9357653260231018, + 0.9376221895217896, + 0.9326146841049194, + 0.9372754693031311, + 0.9379282593727112, + 0.9359501004219055 + ], + "credit_rho": [ + 0.9306489825248718, + 0.9282867312431335, + 0.9364583492279053, + 0.9361423850059509, + 0.9323122501373291, + 0.9380875825881958, + 0.93974369764328, + 
0.9351372122764587, + 0.941116213798523, + 0.9382349252700806, + 0.9381735324859619, + 0.9431807994842529 + ], + "dfa_nudge": [ + 0.020365234464406967, + 0.014575351029634476, + -0.002057630568742752, + 0.006316322833299637, + 0.007856002077460289, + 0.0004078727215528488, + -0.010462434962391853, + -0.017474167048931122, + -0.02493620105087757, + -0.011619189754128456, + 0.022399690002202988, + -0.00210556760430336 + ], + "state_nudge": [ + -0.34937936067581177, + -0.3488026261329651, + -0.3481972813606262, + -0.34827157855033875, + -0.3484804630279541, + -0.3478168845176697, + -0.34806668758392334, + -0.3481366038322449, + -0.3481329381465912, + -0.3476967513561249, + -0.3463967740535736, + -0.34703367948532104 + ], + "credit_nudge": [ + -0.34637150168418884, + -0.345956414937973, + -0.3455013632774353, + -0.34578341245651245, + -0.3462638854980469, + -0.3458409905433655, + -0.3462728261947632, + -0.3466273546218872, + -0.3469495475292206, + -0.34689074754714966, + -0.3459630608558655, + -0.34714198112487793 + ] + } +}
\ No newline at end of file diff --git a/results/toy_lq/toy_lq_v2_seed42_lam1.0_sig0.3_tgw0.0_fm0.0.json b/results/toy_lq/toy_lq_v2_seed42_lam1.0_sig0.3_tgw0.0_fm0.0.json new file mode 100644 index 0000000..78451b9 --- /dev/null +++ b/results/toy_lq/toy_lq_v2_seed42_lam1.0_sig0.3_tgw0.0_fm0.0.json @@ -0,0 +1,282 @@ +{ + "config": { + "d_hidden": 64, + "output_dim": 10, + "num_layers": 12, + "sigma": 0.03, + "batch_size": 256, + "num_steps": 5000, + "lr_fb": 0.001, + "lam": 1.0, + "K": 8, + "ema_momentum": 0.995, + "sigma_bridge": 0.3, + "eval_every": 1000, + "seed": 42, + "gpu": 0, + "output_dir": "results/toy_lq", + "vnet_hidden": 256, + "vnet_layers": 3, + "term_grad_weight": 0.0, + "fm_weight": 0.0 + }, + "log": { + "steps": [ + 1, + 1000, + 2000, + 3000, + 4000, + 5000 + ], + "dfa_costate_cos": [ + 0.001022340264171362, + 0.0024294707691296935, + 0.000613357910575966, + 0.002641987521201372, + 0.0019003628791930776, + 0.004648004221962765 + ], + "state_costate_cos": [ + 0.009337988216429949, + 0.9437081764141718, + 0.9386202742656072, + 0.9438605606555939, + 0.9475679496924082, + 0.9496726940075556 + ], + "credit_costate_cos": [ + 0.023481125943362713, + 0.0021192015459140143, + 0.05908452393487096, + 0.11570050132771333, + 0.09644180163741112, + 0.06304322431484859 + ], + "dfa_rho": [ + 0.015879416760678094, + 0.009290086299491426, + 0.0009949249991526206, + -0.004670841522359599, + -0.0029721508423487344, + 0.0010571565168599288 + ], + "state_rho": [ + 0.0029325426245729127, + 0.9371241927146912, + 0.9219773809115092, + 0.9298903445402781, + 0.9345368842283884, + 0.9332165767749151 + ], + "credit_rho": [ + 0.02086908878603329, + -0.014479975526531538, + 0.04267269264285763, + 0.10674913817395766, + 0.091526560485363, + 0.04765695089008659 + ], + "dfa_nudge": [ + -0.0003799900102118651, + -0.0008909914953013262, + 0.00031573620314399403, + -0.0008827850688248873, + -0.0003006396194299062, + -0.0016971436173965533 + ], + "state_nudge": [ + 
-0.002327537008871635, + -0.33619146794080734, + -0.3439306889971097, + -0.32351043323675793, + -0.33487510432799655, + -0.35304194688796997 + ], + "credit_nudge": [ + -0.007470574385176103, + 0.002750888311614593, + -0.017584003585701186, + -0.037241545505821705, + -0.03454847944279512, + -0.02138534157226483 + ], + "bridge_residual": [], + "state_bridge_loss": [ + 66.01814270019531, + 2.143366813659668, + 1.9674744606018066, + 2.152421712875366, + 2.0728020668029785, + 2.0048255920410156 + ], + "credit_bridge_loss": [ + 111.63633728027344, + 0.35802197456359863, + 0.10368431359529495, + 0.12013768404722214, + 0.07032017409801483, + 0.051539346575737 + ], + "term_loss": [ + 111.63633728027344, + 0.2640027403831482, + 0.05911973863840103, + 0.06323631852865219, + 0.03984691575169563, + 0.03099522553384304 + ], + "bridge_loss": [ + 3.0909704946679994e-06, + 0.09401924908161163, + 0.04456457495689392, + 0.056901365518569946, + 0.030473260208964348, + 0.02054412104189396 + ], + "term_grad_loss": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "fm_loss": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ] + }, + "final_per_layer": { + "dfa_costate_cos": [ + -0.025100916624069214, + -0.025588383898139, + 0.0002230944810435176, + -0.012649456970393658, + -0.003438096959143877, + -0.008626529946923256, + 0.04746557027101517, + 0.0495537593960762, + 0.05357586592435837, + 0.039518602192401886, + -0.06593793630599976, + 0.0067804791033267975 + ], + "state_costate_cos": [ + 0.94776451587677, + 0.9486032724380493, + 0.9489368200302124, + 0.9489387273788452, + 0.9493762850761414, + 0.9498711228370667, + 0.9505667686462402, + 0.9509860277175903, + 0.9503088593482971, + 0.9503939151763916, + 0.9501732587814331, + 0.9501527547836304 + ], + "credit_costate_cos": [ + 0.06872019171714783, + 0.0653301253914833, + 0.06267654895782471, + 0.060872238129377365, + 0.059321627020835876, + 0.05934217572212219, + 0.06014599651098251, + 0.05938819795846939, + 0.06065107509493828, + 
0.06321083009243011, + 0.07017489522695541, + 0.066684789955616 + ], + "dfa_rho": [ + -0.039396628737449646, + -0.062005721032619476, + 0.005523890256881714, + 0.014756813645362854, + -0.029910940676927567, + 0.016373179852962494, + 0.06027424708008766, + 0.07274787873029709, + 0.04506715014576912, + 0.02043943479657173, + -0.07323547452688217, + -0.017947951331734657 + ], + "state_rho": [ + 0.9350653886795044, + 0.9352835416793823, + 0.9357374310493469, + 0.934054434299469, + 0.9226828813552856, + 0.930233359336853, + 0.9396862983703613, + 0.9324671030044556, + 0.9295884966850281, + 0.9347179532051086, + 0.9344969987869263, + 0.9345850348472595 + ], + "credit_rho": [ + 0.07084883749485016, + 0.0004139472730457783, + 0.08371011912822723, + 0.010933063924312592, + 0.07074157148599625, + 0.02977888286113739, + 0.06011004000902176, + 0.020937703549861908, + 0.067134790122509, + 0.04539356008172035, + 0.06114630401134491, + 0.050734590739011765 + ], + "dfa_nudge": [ + 0.010068551637232304, + 0.009215200319886208, + -0.002217007800936699, + 0.007142472080886364, + 0.0020765019580721855, + 0.004165132530033588, + -0.019989464432001114, + -0.018945707008242607, + -0.01947549544274807, + -0.014030318707227707, + 0.023684613406658173, + -0.0020602019503712654 + ], + "state_nudge": [ + -0.3548189401626587, + -0.3539973497390747, + -0.3535424768924713, + -0.35353928804397583, + -0.35366636514663696, + -0.35295385122299194, + -0.35282424092292786, + -0.35278239846229553, + -0.3528212308883667, + -0.3526480793952942, + -0.35135188698768616, + -0.35155725479125977 + ], + "credit_nudge": [ + -0.024559948593378067, + -0.023154649883508682, + -0.02181980386376381, + -0.020881079137325287, + -0.02010848931968212, + -0.019884146749973297, + -0.020017635077238083, + -0.01966693066060543, + -0.020179908722639084, + -0.02114756405353546, + -0.023488853126764297, + -0.021715089678764343 + ] + } +}
\ No newline at end of file diff --git a/results/toy_lq/toy_lq_v2_seed456_lam0.1_sig0.1_tgw1.0_fm0.0.json b/results/toy_lq/toy_lq_v2_seed456_lam0.1_sig0.1_tgw1.0_fm0.0.json new file mode 100644 index 0000000..8e7c7b0 --- /dev/null +++ b/results/toy_lq/toy_lq_v2_seed456_lam0.1_sig0.1_tgw1.0_fm0.0.json @@ -0,0 +1,330 @@ +{ + "config": { + "d_hidden": 64, + "output_dim": 10, + "num_layers": 12, + "sigma": 0.03, + "batch_size": 256, + "num_steps": 8000, + "lr_fb": 0.001, + "lam": 0.1, + "K": 8, + "ema_momentum": 0.995, + "sigma_bridge": 0.1, + "eval_every": 1000, + "seed": 456, + "gpu": 0, + "output_dir": "results/toy_lq", + "vnet_hidden": 256, + "vnet_layers": 3, + "term_grad_weight": 1.0, + "fm_weight": 0.0 + }, + "log": { + "steps": [ + 1, + 1000, + 2000, + 3000, + 4000, + 5000, + 6000, + 7000, + 8000 + ], + "dfa_costate_cos": [ + -0.008305357536301017, + -0.0011448257913192113, + -0.011490714969113469, + -0.005118173896335065, + -0.0028971168600643673, + -0.007146042305976152, + -0.005135039333254099, + -0.006408803591815134, + 0.0030692683843274913 + ], + "state_costate_cos": [ + 0.010766413528472185, + 0.9454400340716044, + 0.9422353406747183, + 0.945327232281367, + 0.9411078443129858, + 0.9500264773766199, + 0.947187085946401, + 0.9483370830615362, + 0.9451590329408646 + ], + "credit_costate_cos": [ + 0.010942678588132063, + 0.8771380881468455, + 0.9226242254177729, + 0.9362819741169611, + 0.9372822741667429, + 0.9464249561230341, + 0.9468309829632441, + 0.9452949613332748, + 0.9439582029978434 + ], + "dfa_rho": [ + -0.0028248391657446823, + -0.001316564545656244, + -0.00823917348558704, + -0.008288809360237792, + -0.009200356512640914, + -0.02464000484906137, + 0.006656331475824118, + -0.014378171879798174, + 0.00816792449525868 + ], + "state_rho": [ + 0.02655455912463367, + 0.9340365380048752, + 0.930340642730395, + 0.9328931520382563, + 0.9261527856191, + 0.9353549828131994, + 0.9365303913752238, + 0.930912658572197, + 0.9325708796580633 + ], + "credit_rho": 
[ + 0.015292729716748, + 0.8341556191444397, + 0.884218767285347, + 0.9175956894954046, + 0.9214809189240137, + 0.9294924139976501, + 0.9364803830782572, + 0.9272431135177612, + 0.9332626809676489 + ], + "dfa_nudge": [ + 0.004384364855165283, + 0.0014555706487347682, + 0.005660574107120435, + 0.003231095770994822, + 0.0024609332904219627, + 0.0033611954810718694, + 0.0025722047624488673, + 0.003350513521581888, + -0.00046457063096265 + ], + "state_nudge": [ + -0.005116054400180777, + -0.36389906456073123, + -0.37967536846796673, + -0.3594902977347374, + -0.3901909242073695, + -0.3562925284107526, + -0.34591514120499295, + -0.3550695503751437, + -0.3515613650282224 + ], + "credit_nudge": [ + -0.003232262640570601, + -0.34101804345846176, + -0.37139702836672467, + -0.355108546713988, + -0.38709117472171783, + -0.3536044582724571, + -0.34420712540547055, + -0.35209985077381134, + -0.3494671831528346 + ], + "bridge_residual": [], + "state_bridge_loss": [ + 66.31975555419922, + 2.3091068267822266, + 2.0314407348632812, + 2.110802412033081, + 1.9804980754852295, + 1.8128150701522827, + 2.0147881507873535, + 2.169416904449463, + 2.0117833614349365 + ], + "credit_bridge_loss": [ + 158.73072814941406, + 11.18570613861084, + 9.38658332824707, + 9.429727554321289, + 10.842954635620117, + 10.344818115234375, + 10.250753402709961, + 10.136574745178223, + 9.871820449829102 + ], + "term_loss": [ + 132.93673706054688, + 4.870186805725098, + 3.9316887855529785, + 4.335302829742432, + 5.871437072753906, + 4.5994768142700195, + 5.077899932861328, + 4.8271284103393555, + 4.638625621795654 + ], + "bridge_loss": [ + 7.166463547036983e-07, + 0.3530547022819519, + 0.20790576934814453, + 0.15763649344444275, + 0.18163490295410156, + 0.154127299785614, + 0.1578863561153412, + 0.13594242930412292, + 0.12283627688884735 + ], + "term_grad_loss": [ + 25.793991088867188, + 5.962464809417725, + 5.246988296508789, + 4.936788082122803, + 4.789882183074951, + 5.591214656829834, + 5.0149664878845215, 
+ 5.173503875732422, + 5.110358238220215 + ], + "fm_loss": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ] + }, + "final_per_layer": { + "dfa_costate_cos": [ + 0.0057477825321257114, + -0.05385865271091461, + 0.031865090131759644, + -0.0450170561671257, + 0.01222146674990654, + 0.048112209886312485, + -0.0029807062819600105, + 0.02491983398795128, + -0.03023526817560196, + -0.0383906215429306, + 0.005803969223052263, + 0.07864317297935486 + ], + "state_costate_cos": [ + 0.9431054592132568, + 0.9437580108642578, + 0.9442368149757385, + 0.9451977014541626, + 0.9453786611557007, + 0.9453213214874268, + 0.9453005790710449, + 0.945610761642456, + 0.9459618330001831, + 0.946029007434845, + 0.9460264444351196, + 0.9459818005561829 + ], + "credit_costate_cos": [ + 0.9404189586639404, + 0.9415367841720581, + 0.9419513940811157, + 0.9428710341453552, + 0.943200409412384, + 0.9437626004219055, + 0.9443795680999756, + 0.9448325037956238, + 0.945106029510498, + 0.9459753036499023, + 0.9465078115463257, + 0.9469560384750366 + ], + "dfa_rho": [ + 0.04672882705926895, + -0.030157513916492462, + 0.03139471262693405, + -0.02426629513502121, + -0.003147948533296585, + 0.06672847270965576, + -0.0028628872241824865, + -0.0006099608726799488, + -0.06767985224723816, + -0.07952223718166351, + 0.023338939994573593, + 0.13807083666324615 + ], + "state_rho": [ + 0.9322999715805054, + 0.9332318902015686, + 0.933842658996582, + 0.9367775917053223, + 0.9355518221855164, + 0.9306066036224365, + 0.9324768781661987, + 0.9316605925559998, + 0.9233090281486511, + 0.9335967302322388, + 0.9341922998428345, + 0.933304488658905 + ], + "credit_rho": [ + 0.929013192653656, + 0.925915002822876, + 0.9249007701873779, + 0.9293269515037537, + 0.9353340268135071, + 0.9355732202529907, + 0.9336704611778259, + 0.9395359754562378, + 0.934350848197937, + 0.9350711107254028, + 0.9423503279685974, + 0.9341102838516235 + ], + "dfa_nudge": [ + -0.0025230227038264275, + 0.01772259920835495, + 
-0.011610012501478195, + 0.01727830246090889, + -0.002929902635514736, + -0.01679525338113308, + 0.0033558662980794907, + -0.007272847928106785, + 0.013050400651991367, + 0.014434966258704662, + -0.0013312064111232758, + -0.02895473688840866 + ], + "state_nudge": [ + -0.3518102467060089, + -0.3520665764808655, + -0.35131847858428955, + -0.35165226459503174, + -0.35131698846817017, + -0.35155874490737915, + -0.3516783118247986, + -0.3522875905036926, + -0.35148167610168457, + -0.3511171340942383, + -0.35132837295532227, + -0.3511199951171875 + ], + "credit_nudge": [ + -0.34800052642822266, + -0.3485172390937805, + -0.3479737639427185, + -0.3486989140510559, + -0.3486787676811218, + -0.3492557406425476, + -0.3497552275657654, + -0.35073035955429077, + -0.35016047954559326, + -0.3501723110675812, + -0.3507159650325775, + -0.35094690322875977 + ] + } +}
\ No newline at end of file diff --git a/results/toy_lq/value_net_seed42.pt b/results/toy_lq/value_net_seed42.pt Binary files differ new file mode 100644 index 0000000..0cb6683 --- /dev/null +++ b/results/toy_lq/value_net_seed42.pt |
