"""
Plot the §4 penalty-rescue figure.

Per-epoch trajectories of the protocol's (a) and (b) diagnostics on vanilla
DFA vs penalized DFA, plus accuracy. Shows visually that the penalty rescues
both diagnostics back into the healthy regime, while the headline accuracy
gain is small (the second failure mode persists).

Data sources:
  - vanilla DFA trajectory:
      results/snapshot_evolution_v2/snapshot_evolution_s42.json
  - penalized DFA (lam=1e-2):
      results/dfa_residual_penalty/dfa_pen_lam0.01_s42.json
  - DFA-shallow baseline 3-seed mean (drawn as horizontal line): 0.349
  - BP-trainable 3-seed mean: 0.6147 (100 ep) / 0.585 (matched 30 ep)

Run:
  python -m protocol.examples.plot_penalty_rescue
"""
import os
import sys
import json

import numpy as np
import matplotlib

matplotlib.use("Agg")  # headless backend: the figure is only written to disk
import matplotlib.pyplot as plt

REPO_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)

VANILLA_LOG_PATH = os.path.join(
    REPO_ROOT, "results", "snapshot_evolution_v2", "snapshot_evolution_s42.json"
)
PENALTY_LOG_PATH = os.path.join(
    REPO_ROOT, "results", "dfa_residual_penalty", "dfa_pen_lam0.01_s42.json"
)
OUT_PATH = os.path.join(
    REPO_ROOT, "results", "protocol_audit", "figure_penalty_rescue_s42.png"
)

# Reference levels drawn as horizontal lines on the panels.
BP_H_L_NORM = 200          # BP reference level for ||h_L|| on panel (a)
GRAD_FLOOR = 1e-7          # gradient-norm floor line on panel (b)
DFA_SHALLOW_ACC = 0.349    # DFA-shallow baseline, 3-seed mean
BP_TRAINABLE_ACC = 0.6147  # BP-trainable, 3-seed mean, 100 epochs

PENALTY_LABEL = r"DFA + $\lambda \|f_l\|^2$"


def load_vanilla_log(path=VANILLA_LOG_PATH):
    """Load the vanilla-DFA per-epoch trajectory.

    Args:
        path: JSON file whose ``"dfa_log"`` entry is a list of per-epoch
            records. Defaults to the seed-42 snapshot-evolution run.

    Returns:
        A list with one dict per epoch, holding ``epoch``, ``h_L`` (last
        entry of ``hidden_norms``), ``g_2`` (index 2 of
        ``bp_grad_norms_per_sample_med``) and ``acc`` (``acc_eval``).
    """
    with open(path, encoding="utf-8") as f:
        log = json.load(f)["dfa_log"]
    return [
        {
            "epoch": e["epoch"],
            "h_L": e["hidden_norms"][-1],
            "g_2": e["bp_grad_norms_per_sample_med"][2],
            "acc": e["acc_eval"],
        }
        for e in log
    ]


def load_penalty_log(path=PENALTY_LOG_PATH):
    """Load the penalized-DFA (lam=1e-2) run.

    Args:
        path: JSON file with ``"log"`` and ``"final_test_acc"`` entries.
            Defaults to the seed-42 lam=0.01 run.

    Returns:
        ``(log, final_test_acc)``. Each ``log`` record is read below via the
        keys ``epoch``, ``h_L_norm``, ``g_2_norm`` and ``acc_eval``.
    """
    with open(path, encoding="utf-8") as f:
        d = json.load(f)
    return d["log"], d["final_test_acc"]


def _plot_trajectories(ax, vanilla, penalty, vanilla_key, penalty_key,
                       penalty_label):
    """Draw the vanilla (red line) and penalized (green markers) curves of one metric."""
    ax.plot(
        [e["epoch"] for e in vanilla],
        [e[vanilla_key] for e in vanilla],
        label="vanilla DFA", color="C3", lw=2,
    )
    ax.plot(
        [e["epoch"] for e in penalty],
        [e[penalty_key] for e in penalty],
        label=penalty_label, color="C2", lw=2, marker="o", markersize=4,
    )


def _decorate(ax, ylabel, title, legend_loc, grid_which="major"):
    """Apply the shared axis labels, title, legend and grid styling."""
    ax.set_xlabel("epoch", fontsize=10)
    ax.set_ylabel(ylabel, fontsize=10)
    ax.set_title(title, fontsize=11)
    # Called after all lines/axhlines are added so the legend includes them.
    ax.legend(loc=legend_loc, fontsize=7.5)
    ax.grid(True, which=grid_which, alpha=0.3)


def main():
    """Build the three-panel penalty-rescue figure and save it as a PNG."""
    vanilla = load_vanilla_log()
    # The loader also returns the final test accuracy; it is not drawn here.
    penalty, _final_acc = load_penalty_log()

    fig, (ax_h, ax_g, ax_acc) = plt.subplots(
        1, 3, figsize=(13, 4.0), gridspec_kw={"wspace": 0.35}
    )

    # --- (a) ||h_L||: norm of the last entry of hidden_norms ---
    _plot_trajectories(
        ax_h, vanilla, penalty, "h_L", "h_L_norm",
        PENALTY_LABEL + r" ($\lambda=10^{-2}$)",
    )
    ax_h.axhline(BP_H_L_NORM, color="C0", linestyle=":", lw=1,
                 label=r"BP $\|h_L\| \approx 200$")
    ax_h.set_yscale("log")
    _decorate(ax_h, r"$\|h_L\|_2$ (log)", "(a) residual stream",
              "lower right", grid_which="both")

    # --- (b) ||g_2||: BP gradient norm at index 2 of the per-sample medians ---
    _plot_trajectories(ax_g, vanilla, penalty, "g_2", "g_2_norm", PENALTY_LABEL)
    ax_g.axhline(GRAD_FLOOR, color="k", linestyle="--", lw=1.2,
                 label=r"floor $10^{-7}$")
    ax_g.set_yscale("log")
    # Label matches the plotted quantity (g_2), not g_L.
    _decorate(ax_g, r"$\|g_2\|_2$ (log)", "(b) BP grad at hidden layer",
              "lower left", grid_which="both")

    # --- accuracy (protocol diagnostic (d)) ---
    _plot_trajectories(ax_acc, vanilla, penalty, "acc", "acc_eval", PENALTY_LABEL)
    ax_acc.axhline(DFA_SHALLOW_ACC, color="k", linestyle="--", lw=1.2,
                   label=f"DFA-shallow {DFA_SHALLOW_ACC:.3f}")
    ax_acc.axhline(BP_TRAINABLE_ACC, color="C0", linestyle=":", lw=1,
                   label=f"BP-trainable 100ep {BP_TRAINABLE_ACC:.3f}")
    _decorate(ax_acc, "test acc", "(d) headline accuracy", "lower right")
    ax_acc.set_ylim(0.05, 0.7)

    fig.suptitle(
        r"Penalty rescue (4-block d=256 ResMLP, seed 42): the $\|f_l\|^2$ penalty fixes (a) and (b),"
        "\nbut the deep blocks still fail to clear the frozen baseline (the second failure mode)",
        fontsize=11, y=1.05,
    )
    fig.tight_layout()

    out_path = OUT_PATH
    # Create results/protocol_audit/ on a fresh checkout instead of failing in savefig.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, dpi=140, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {out_path}")


if __name__ == "__main__":
    main()