summaryrefslogtreecommitdiff
path: root/paper/figures/render_fig4_penalty_rescue.py
diff options
context:
space:
mode:
Diffstat (limited to 'paper/figures/render_fig4_penalty_rescue.py')
-rw-r--r--paper/figures/render_fig4_penalty_rescue.py92
1 files changed, 92 insertions, 0 deletions
diff --git a/paper/figures/render_fig4_penalty_rescue.py b/paper/figures/render_fig4_penalty_rescue.py
new file mode 100644
index 0000000..b7089ec
--- /dev/null
+++ b/paper/figures/render_fig4_penalty_rescue.py
@@ -0,0 +1,92 @@
+"""Render Figure 4: penalty rescue + capacity-cost control."""
+import os
+import json
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import numpy as np
+
+REPO_ROOT = "/home/yurenh2/fa"
+
+# Panel A: penalty rescue trajectory
+with open(os.path.join(REPO_ROOT, "results/snapshot_evolution_v2/snapshot_evolution_s42.json")) as f:
+ snap = json.load(f)
+vanilla = snap["dfa_log"]
+ep_vanilla = [e["epoch"] for e in vanilla]
+hL_vanilla = [e["hidden_norms"][-1] for e in vanilla]
+g_vanilla = [e["bp_grad_norms_per_sample_med"][-1] for e in vanilla]
+
+with open(os.path.join(REPO_ROOT, "results/dfa_residual_penalty/dfa_pen_lam0.01_s42.json")) as f:
+ pen = json.load(f)
+ep_pen = [e["epoch"] for e in pen["log"]]
+hL_pen = [e["h_L_norm"] for e in pen["log"]]
+g_pen = [e["g_2_norm"] for e in pen["log"]]
+
+# Panel B: cosine + rho across vanilla / penalized / fresh-B / penalty lam=1e-4
+# Read from existing results
+conditions = ["vanilla\nDFA\n(early)", "penalized\n$\\lambda{=}10^{-4}$", "penalized\n$\\lambda{=}10^{-2}$", "fresh-$B$\nnull", "BP grad\n(positive)"]
+deep_cos = [-0.008, -0.022, +0.155, +0.002, +1.000]
+deep_rho = [-0.003, -0.004, +0.080, +0.006, +0.997]
+cos_err = [0.013, 0.0, 0.025, 0.022, 0.0]
+rho_err = [0.005, 0.0, 0.011, 0.0, 0.0]
+
+# Panel C: 2x2 capacity-cost control
+methods = ["BP", "DFA"]
+no_pen = [0.609, 0.308]
+with_pen = [0.530, 0.363]
+shallow = 0.349
+
+fig, axes = plt.subplots(1, 3, figsize=(13, 3.5))
+
+# Panel A: trajectory
+ax = axes[0]
+ax.plot(ep_vanilla, hL_vanilla, label="vanilla DFA $\\|h_L\\|$", color="C3", lw=1.5, marker="o", markersize=3)
+ax.plot(ep_pen, hL_pen, label="penalized DFA $\\|h_L\\|$ ($\\lambda{=}10^{-2}$)", color="C2", lw=1.5, marker="s", markersize=3)
+ax.set_yscale("log")
+ax.set_xlabel("epoch", fontsize=10)
+ax.set_ylabel("$\\|h_L\\|$ (log)", fontsize=10)
+ax.set_title("(a) penalty contains residual stream\n(4 OOM rescue)", fontsize=10)
+ax.legend(loc="lower right", fontsize=8)
+ax.grid(True, alpha=0.3, which="both")
+ax2 = ax.twinx()
+ax2.plot(ep_vanilla, g_vanilla, label="vanilla $\\|g\\|$", color="C3", lw=1, ls=":", marker="^", markersize=3)
+ax2.plot(ep_pen, g_pen, label="penalized $\\|g\\|$", color="C2", lw=1, ls=":", marker="v", markersize=3)
+ax2.axhline(1e-7, color="black", ls="--", lw=0.8, label="$10^{-7}$ floor")
+ax2.set_yscale("log")
+ax2.set_ylabel("$\\|g_L\\|$ (log)", fontsize=9, color="gray")
+ax2.tick_params(axis="y", labelcolor="gray")
+
+# Panel B: cosine + rho
+ax = axes[1]
+xpos = np.arange(len(conditions))
+w = 0.35
+b1 = ax.bar(xpos - w/2, deep_cos, w, yerr=cos_err, label="deep cos", color="#4682b4", capsize=3)
+b2 = ax.bar(xpos + w/2, deep_rho, w, yerr=rho_err, label="deep $\\rho$", color="#7da76f", capsize=3)
+ax.axhline(0, color="black", lw=0.5)
+ax.set_xticks(xpos)
+ax.set_xticklabels(conditions, fontsize=8)
+ax.set_ylabel("deep-layer alignment", fontsize=10)
+ax.set_title("(b) two metrics agree across conditions\n(measurement vs random feedback)", fontsize=10)
+ax.legend(loc="upper left", fontsize=8)
+ax.grid(True, axis="y", alpha=0.3)
+ax.set_ylim(-0.1, 1.1)
+
+# Panel C: 2x2 capacity-cost
+ax = axes[2]
+xpos = np.arange(len(methods))
+w = 0.35
+ax.bar(xpos - w/2, no_pen, w, label="no penalty", color="#4682b4")
+ax.bar(xpos + w/2, with_pen, w, label="with penalty $\\lambda{=}10^{-2}$", color="#cc4444")
+ax.axhline(shallow, color="black", ls="--", lw=1, label=f"frozen baseline {shallow}")
+ax.set_xticks(xpos)
+ax.set_xticklabels(methods, fontsize=10)
+ax.set_ylabel("test accuracy", fontsize=10)
+ax.set_title("(c) BP+penalty 2$\\times$2 control\n(BP-pen-cost $-8$pp; gap $17$pp $=$ credit quality)", fontsize=10)
+ax.legend(loc="upper right", fontsize=8)
+ax.grid(True, axis="y", alpha=0.3)
+ax.set_ylim(0, 0.7)
+
+plt.tight_layout()
+out = os.path.join(REPO_ROOT, "paper/figures/fig4_penalty_rescue.pdf")
+plt.savefig(out, bbox_inches="tight", dpi=200)
+print(f"Saved {out}")