diff options
| -rw-r--r-- | protocol/examples/plot_penalty_rescue.py | 114 | ||||
| -rw-r--r-- | results/protocol_audit/figure_penalty_rescue_s42.png | bin | 0 -> 132144 bytes |
2 files changed, 114 insertions, 0 deletions
diff --git a/protocol/examples/plot_penalty_rescue.py b/protocol/examples/plot_penalty_rescue.py new file mode 100644 index 0000000..37b0fa9 --- /dev/null +++ b/protocol/examples/plot_penalty_rescue.py @@ -0,0 +1,114 @@ +""" +Plot the §4 penalty-rescue figure: per-epoch trajectories of the protocol's +(a) and (b) diagnostics on vanilla DFA vs penalized DFA, plus accuracy. + +Shows visually that the penalty rescues both diagnostics back into the +healthy regime, while the headline accuracy gain is small (the second +failure mode persists). + +Data sources: + - vanilla DFA trajectory: results/snapshot_evolution_v2/snapshot_evolution_s42.json + - penalized DFA (lam=1e-2): results/dfa_residual_penalty/dfa_pen_lam0.01_s42.json + - DFA-shallow baseline 3-seed mean (drawn as horizontal line): 0.349 + - BP-trainable 3-seed mean: 0.609 + +Run: + python -m protocol.examples.plot_penalty_rescue +""" +import os +import sys +import json + +import numpy as np +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +REPO_ROOT = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) + + +def load_vanilla_log(): + with open(os.path.join(REPO_ROOT, "results/snapshot_evolution_v2/snapshot_evolution_s42.json")) as f: + d = json.load(f) + log = d["dfa_log"] + return [ + { + "epoch": e["epoch"], + "h_L": e["hidden_norms"][-1], + "g_2": e["bp_grad_norms_per_sample_med"][2], + "acc": e["acc_eval"], + } + for e in log + ] + + +def load_penalty_log(): + with open(os.path.join(REPO_ROOT, "results/dfa_residual_penalty/dfa_pen_lam0.01_s42.json")) as f: + d = json.load(f) + return d["log"], d["final_test_acc"] + + +def main(): + vanilla = load_vanilla_log() + penalty, penalty_final_acc = load_penalty_log() + + fig, axes = plt.subplots(1, 3, figsize=(13, 4.0), gridspec_kw={"wspace": 0.35}) + + # --- (1) ‖h_L‖ --- + ax = axes[0] + ax.plot([e["epoch"] for e in vanilla], [e["h_L"] for e in vanilla], + label="vanilla DFA", color="C3", lw=2) + ax.plot([e["epoch"] for e in penalty], [e["h_L_norm"] for e in penalty], + label=r"DFA + $\lambda \|f_l\|^2$ ($\lambda=10^{-2}$)", color="C2", lw=2, marker="o", markersize=4) + ax.axhline(200, color="C0", linestyle=":", lw=1, label=r"BP $\|h_L\| \approx 200$") + ax.set_yscale("log") + ax.set_xlabel("epoch", fontsize=10) + ax.set_ylabel(r"$\|h_L\|_2$ (log)", fontsize=10) + ax.set_title("(a) residual stream", fontsize=11) + ax.legend(loc="lower right", fontsize=7.5) + ax.grid(True, which="both", alpha=0.3) + + # --- (2) ‖g_2‖ --- + ax = axes[1] + ax.plot([e["epoch"] for e in vanilla], [e["g_2"] for e in vanilla], + label="vanilla DFA", color="C3", lw=2) + ax.plot([e["epoch"] for e in penalty], [e["g_2_norm"] for e in penalty], + label=r"DFA + $\lambda \|f_l\|^2$", color="C2", lw=2, marker="o", markersize=4) + ax.axhline(1e-7, color="k", linestyle="--", lw=1.2, label=r"floor $10^{-7}$") + ax.set_yscale("log") + ax.set_xlabel("epoch", fontsize=10) + ax.set_ylabel(r"$\|g_L\|_2$ (log)", fontsize=10) + ax.set_title("(b) BP grad at hidden layer", fontsize=11) + ax.legend(loc="lower left", fontsize=7.5) + ax.grid(True, which="both", alpha=0.3) + + # --- (3) accuracy --- + ax = axes[2] + ax.plot([e["epoch"] for e in vanilla], [e["acc"] for e in vanilla], + label="vanilla DFA", color="C3", lw=2) + ax.plot([e["epoch"] for e in penalty], [e["acc_eval"] for e in penalty], + label=r"DFA + $\lambda \|f_l\|^2$", color="C2", lw=2, marker="o", markersize=4) + ax.axhline(0.349, color="k", linestyle="--", lw=1.2, label="DFA-shallow 0.349") + ax.axhline(0.609, color="C0", linestyle=":", lw=1, label="BP-trainable 0.609") + ax.set_xlabel("epoch", fontsize=10) + ax.set_ylabel("test acc", fontsize=10) + ax.set_title("(d) headline accuracy", fontsize=11) + ax.legend(loc="lower right", fontsize=7.5) + ax.grid(True, alpha=0.3) + ax.set_ylim(0.05, 0.7) + + fig.suptitle( + r"Penalty rescue (4-block d=256 ResMLP, seed 42): the $\|f_l\|^2$ penalty fixes (a) and (b)," + "\nbut the deep blocks still fail to clear the frozen baseline (the second failure mode)", + fontsize=11, y=1.05 + ) + fig.tight_layout() + out_path = os.path.join(REPO_ROOT, "results/protocol_audit/figure_penalty_rescue_s42.png") + fig.savefig(out_path, dpi=140, bbox_inches="tight") + print(f"Saved {out_path}") + + +if __name__ == "__main__": + main() diff --git a/results/protocol_audit/figure_penalty_rescue_s42.png b/results/protocol_audit/figure_penalty_rescue_s42.png Binary files differnew file mode 100644 index 0000000..734de4f --- /dev/null +++ b/results/protocol_audit/figure_penalty_rescue_s42.png |
