From d5185a3cc692fe96c93bbc5d7b286b7080ba7458 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Tue, 7 Apr 2026 23:32:44 -0500 Subject: =?UTF-8?q?Add=20=C2=A74=20penalty=20rescue=20figure:=20visual=20t?= =?UTF-8?q?wo-failure-modes=20story?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 3-panel side-by-side showing per-epoch trajectories of vanilla DFA vs DFA + lambda*||f||^2 penalty: (a) ||h_L||: vanilla 4e8 vs penalty 4e4 (4 OOM rescue) (b) ||g_L||: vanilla 5e-10 vs penalty ~1e-6 (4 OOM rescue) (d) test acc: vanilla 0.31 vs penalty 0.36 vs frozen baseline 0.349 vs BP 0.61 The visual story: (a) and (b) show the penalty pulling the diagnostics back into the healthy regime, but (d) shows the rescue translates to only +1 pp above the DFA-shallow baseline and 24 pp below BP-trainable. The two failure modes (scale + direction) are visually separable: scale is fixed, direction is not. Together with figure_audit_5method.png and figure_cross_arch_temporal_s42.png, this is the third paper-ready figure for §3-§4. --- protocol/examples/plot_penalty_rescue.py | 114 +++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) create mode 100644 protocol/examples/plot_penalty_rescue.py (limited to 'protocol/examples') diff --git a/protocol/examples/plot_penalty_rescue.py b/protocol/examples/plot_penalty_rescue.py new file mode 100644 index 0000000..37b0fa9 --- /dev/null +++ b/protocol/examples/plot_penalty_rescue.py @@ -0,0 +1,114 @@ +""" +Plot the §4 penalty-rescue figure: per-epoch trajectories of the protocol's +(a) and (b) diagnostics on vanilla DFA vs penalized DFA, plus accuracy. + +Shows visually that the penalty rescues both diagnostics back into the +healthy regime, while the headline accuracy gain is small (the second +failure mode persists). + +Data sources: + - vanilla DFA trajectory: results/snapshot_evolution_v2/snapshot_evolution_s42.json + - penalized DFA (lam=1e-2): results/dfa_residual_penalty/dfa_pen_lam0.01_s42.json + - DFA-shallow baseline 3-seed mean (drawn as horizontal line): 0.349 + - BP-trainable 3-seed mean: 0.609 + +Run: + python -m protocol.examples.plot_penalty_rescue +""" +import os +import sys +import json + +import numpy as np +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +REPO_ROOT = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) + + +def load_vanilla_log(): + with open(os.path.join(REPO_ROOT, "results/snapshot_evolution_v2/snapshot_evolution_s42.json")) as f: + d = json.load(f) + log = d["dfa_log"] + return [ + { + "epoch": e["epoch"], + "h_L": e["hidden_norms"][-1], + "g_2": e["bp_grad_norms_per_sample_med"][2], + "acc": e["acc_eval"], + } + for e in log + ] + + +def load_penalty_log(): + with open(os.path.join(REPO_ROOT, "results/dfa_residual_penalty/dfa_pen_lam0.01_s42.json")) as f: + d = json.load(f) + return d["log"], d["final_test_acc"] + + +def main(): + vanilla = load_vanilla_log() + penalty, penalty_final_acc = load_penalty_log() + + fig, axes = plt.subplots(1, 3, figsize=(13, 4.0), gridspec_kw={"wspace": 0.35}) + + # --- (1) ‖h_L‖ --- + ax = axes[0] + ax.plot([e["epoch"] for e in vanilla], [e["h_L"] for e in vanilla], + label="vanilla DFA", color="C3", lw=2) + ax.plot([e["epoch"] for e in penalty], [e["h_L_norm"] for e in penalty], + label=r"DFA + $\lambda \|f_l\|^2$ ($\lambda=10^{-2}$)", color="C2", lw=2, marker="o", markersize=4) + ax.axhline(200, color="C0", linestyle=":", lw=1, label=r"BP $\|h_L\| \approx 200$") + ax.set_yscale("log") + ax.set_xlabel("epoch", fontsize=10) + ax.set_ylabel(r"$\|h_L\|_2$ (log)", fontsize=10) + ax.set_title("(a) residual stream", fontsize=11) + ax.legend(loc="lower right", fontsize=7.5) + ax.grid(True, which="both", alpha=0.3) + + # --- (2) ‖g_2‖ --- + ax = axes[1] + ax.plot([e["epoch"] for e in vanilla], [e["g_2"] for e in vanilla], + label="vanilla DFA", color="C3", lw=2) + ax.plot([e["epoch"] for e in penalty], [e["g_2_norm"] for e in penalty], + label=r"DFA + $\lambda \|f_l\|^2$", color="C2", lw=2, marker="o", markersize=4) + ax.axhline(1e-7, color="k", linestyle="--", lw=1.2, label=r"floor $10^{-7}$") + ax.set_yscale("log") + ax.set_xlabel("epoch", fontsize=10) + ax.set_ylabel(r"$\|g_L\|_2$ (log)", fontsize=10) + ax.set_title("(b) BP grad at hidden layer", fontsize=11) + ax.legend(loc="lower left", fontsize=7.5) + ax.grid(True, which="both", alpha=0.3) + + # --- (3) accuracy --- + ax = axes[2] + ax.plot([e["epoch"] for e in vanilla], [e["acc"] for e in vanilla], + label="vanilla DFA", color="C3", lw=2) + ax.plot([e["epoch"] for e in penalty], [e["acc_eval"] for e in penalty], + label=r"DFA + $\lambda \|f_l\|^2$", color="C2", lw=2, marker="o", markersize=4) + ax.axhline(0.349, color="k", linestyle="--", lw=1.2, label="DFA-shallow 0.349") + ax.axhline(0.609, color="C0", linestyle=":", lw=1, label="BP-trainable 0.609") + ax.set_xlabel("epoch", fontsize=10) + ax.set_ylabel("test acc", fontsize=10) + ax.set_title("(d) headline accuracy", fontsize=11) + ax.legend(loc="lower right", fontsize=7.5) + ax.grid(True, alpha=0.3) + ax.set_ylim(0.05, 0.7) + + fig.suptitle( + r"Penalty rescue (4-block d=256 ResMLP, seed 42): the $\|f_l\|^2$ penalty fixes (a) and (b)," + "\nbut the deep blocks still fail to clear the frozen baseline (the second failure mode)", + fontsize=11, y=1.05 + ) + fig.tight_layout() + out_path = os.path.join(REPO_ROOT, "results/protocol_audit/figure_penalty_rescue_s42.png") + fig.savefig(out_path, dpi=140, bbox_inches="tight") + print(f"Saved {out_path}") + + +if __name__ == "__main__": + main() -- cgit v1.2.3