summaryrefslogtreecommitdiff
path: root/protocol
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:32:44 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:32:44 -0500
commitd5185a3cc692fe96c93bbc5d7b286b7080ba7458 (patch)
tree7872cdc972242082d71e39ebe1c414f0dd1bb636 /protocol
parentcb0e6b3f3e9c3d0cb8335be1621478cf4c786375 (diff)
Add §4 penalty rescue figure: visual two-failure-modes story
3-panel side-by-side showing per-epoch trajectories of vanilla DFA vs DFA + lambda*||f||^2 penalty: (a) ||h_L||: vanilla 4e8 vs penalty 4e4 (4 OOM rescue) (b) ||g_L||: vanilla 5e-10 vs penalty ~1e-6 (4 OOM rescue) (d) test acc: vanilla 0.31 vs penalty 0.36 vs frozen baseline 0.349 vs BP 0.61 The visual story: (a) and (b) show the penalty pulling the diagnostics back into the healthy regime, but (d) shows the rescue translates to only +1 pp above the DFA-shallow baseline and 24 pp below BP-trainable. The two failure modes (scale + direction) are visually separable: scale is fixed, direction is not. Together with figure_audit_5method.png and figure_cross_arch_temporal_s42.png, this is the third paper-ready figure for §3-§4.
Diffstat (limited to 'protocol')
-rw-r--r--protocol/examples/plot_penalty_rescue.py114
1 files changed, 114 insertions, 0 deletions
diff --git a/protocol/examples/plot_penalty_rescue.py b/protocol/examples/plot_penalty_rescue.py
new file mode 100644
index 0000000..37b0fa9
--- /dev/null
+++ b/protocol/examples/plot_penalty_rescue.py
@@ -0,0 +1,114 @@
+"""
+Plot the §4 penalty-rescue figure: per-epoch trajectories of the protocol's
+(a) and (b) diagnostics on vanilla DFA vs penalized DFA, plus accuracy.
+
+Shows visually that the penalty rescues both diagnostics back into the
+healthy regime, while the headline accuracy gain is small (the second
+failure mode persists).
+
+Data sources:
+ - vanilla DFA trajectory: results/snapshot_evolution_v2/snapshot_evolution_s42.json
+ - penalized DFA (lam=1e-2): results/dfa_residual_penalty/dfa_pen_lam0.01_s42.json
+ - DFA-shallow baseline 3-seed mean (drawn as horizontal line): 0.349
+ - BP-trainable 3-seed mean: 0.609
+
+Run:
+ python -m protocol.examples.plot_penalty_rescue
+"""
+import os
+import sys
+import json
+
+import numpy as np
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+
+REPO_ROOT = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+
+
+def load_vanilla_log():
+ with open(os.path.join(REPO_ROOT, "results/snapshot_evolution_v2/snapshot_evolution_s42.json")) as f:
+ d = json.load(f)
+ log = d["dfa_log"]
+ return [
+ {
+ "epoch": e["epoch"],
+ "h_L": e["hidden_norms"][-1],
+ "g_2": e["bp_grad_norms_per_sample_med"][2],
+ "acc": e["acc_eval"],
+ }
+ for e in log
+ ]
+
+
+def load_penalty_log():
+ with open(os.path.join(REPO_ROOT, "results/dfa_residual_penalty/dfa_pen_lam0.01_s42.json")) as f:
+ d = json.load(f)
+ return d["log"], d["final_test_acc"]
+
+
+def main():
+ vanilla = load_vanilla_log()
+ penalty, penalty_final_acc = load_penalty_log()
+
+ fig, axes = plt.subplots(1, 3, figsize=(13, 4.0), gridspec_kw={"wspace": 0.35})
+
+ # --- (1) ‖h_L‖ ---
+ ax = axes[0]
+ ax.plot([e["epoch"] for e in vanilla], [e["h_L"] for e in vanilla],
+ label="vanilla DFA", color="C3", lw=2)
+ ax.plot([e["epoch"] for e in penalty], [e["h_L_norm"] for e in penalty],
+ label=r"DFA + $\lambda \|f_l\|^2$ ($\lambda=10^{-2}$)", color="C2", lw=2, marker="o", markersize=4)
+ ax.axhline(200, color="C0", linestyle=":", lw=1, label=r"BP $\|h_L\| \approx 200$")
+ ax.set_yscale("log")
+ ax.set_xlabel("epoch", fontsize=10)
+ ax.set_ylabel(r"$\|h_L\|_2$ (log)", fontsize=10)
+ ax.set_title("(a) residual stream", fontsize=11)
+ ax.legend(loc="lower right", fontsize=7.5)
+ ax.grid(True, which="both", alpha=0.3)
+
+ # --- (2) ‖g_2‖ ---
+ ax = axes[1]
+ ax.plot([e["epoch"] for e in vanilla], [e["g_2"] for e in vanilla],
+ label="vanilla DFA", color="C3", lw=2)
+ ax.plot([e["epoch"] for e in penalty], [e["g_2_norm"] for e in penalty],
+ label=r"DFA + $\lambda \|f_l\|^2$", color="C2", lw=2, marker="o", markersize=4)
+ ax.axhline(1e-7, color="k", linestyle="--", lw=1.2, label=r"floor $10^{-7}$")
+ ax.set_yscale("log")
+ ax.set_xlabel("epoch", fontsize=10)
+ ax.set_ylabel(r"$\|g_L\|_2$ (log)", fontsize=10)
+ ax.set_title("(b) BP grad at hidden layer", fontsize=11)
+ ax.legend(loc="lower left", fontsize=7.5)
+ ax.grid(True, which="both", alpha=0.3)
+
+ # --- (3) accuracy ---
+ ax = axes[2]
+ ax.plot([e["epoch"] for e in vanilla], [e["acc"] for e in vanilla],
+ label="vanilla DFA", color="C3", lw=2)
+ ax.plot([e["epoch"] for e in penalty], [e["acc_eval"] for e in penalty],
+ label=r"DFA + $\lambda \|f_l\|^2$", color="C2", lw=2, marker="o", markersize=4)
+ ax.axhline(0.349, color="k", linestyle="--", lw=1.2, label="DFA-shallow 0.349")
+ ax.axhline(0.609, color="C0", linestyle=":", lw=1, label="BP-trainable 0.609")
+ ax.set_xlabel("epoch", fontsize=10)
+ ax.set_ylabel("test acc", fontsize=10)
+ ax.set_title("(d) headline accuracy", fontsize=11)
+ ax.legend(loc="lower right", fontsize=7.5)
+ ax.grid(True, alpha=0.3)
+ ax.set_ylim(0.05, 0.7)
+
+ fig.suptitle(
+ r"Penalty rescue (4-block d=256 ResMLP, seed 42): the $\|f_l\|^2$ penalty fixes (a) and (b),"
+ "\nbut the deep blocks still fail to clear the frozen baseline (the second failure mode)",
+ fontsize=11, y=1.05
+ )
+ fig.tight_layout()
+ out_path = os.path.join(REPO_ROOT, "results/protocol_audit/figure_penalty_rescue_s42.png")
+ fig.savefig(out_path, dpi=140, bbox_inches="tight")
+ print(f"Saved {out_path}")
+
+
+if __name__ == "__main__":
+ main()