diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 23:32:44 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 23:32:44 -0500 |
| commit | d5185a3cc692fe96c93bbc5d7b286b7080ba7458 (patch) | |
| tree | 7872cdc972242082d71e39ebe1c414f0dd1bb636 /protocol | |
| parent | cb0e6b3f3e9c3d0cb8335be1621478cf4c786375 (diff) | |
Add §4 penalty rescue figure: visual two-failure-modes story
3-panel side-by-side showing per-epoch trajectories of vanilla DFA vs
DFA + lambda*||f||^2 penalty:
(a) ||h_L||: vanilla 4e8 vs penalty 4e4 (4 OOM rescue)
(b) ||g_L||: vanilla 5e-10 vs penalty ~1e-6 (4 OOM rescue)
(d) test acc: vanilla 0.31 vs penalty 0.36 vs frozen baseline 0.349 vs BP 0.61
The visual story: (a) and (b) show the penalty pulling the diagnostics
back into the healthy regime, but (d) shows the rescue translates to
only +1 pp above the DFA-shallow baseline and 24 pp below BP-trainable.
The two failure modes (scale + direction) are visually separable: scale
is fixed, direction is not.
Together with figure_audit_5method.png and figure_cross_arch_temporal_s42.png,
this is the third paper-ready figure for §3-§4.
Diffstat (limited to 'protocol')
| -rw-r--r-- | protocol/examples/plot_penalty_rescue.py | 114 |
1 files changed, 114 insertions, 0 deletions
diff --git a/protocol/examples/plot_penalty_rescue.py b/protocol/examples/plot_penalty_rescue.py new file mode 100644 index 0000000..37b0fa9 --- /dev/null +++ b/protocol/examples/plot_penalty_rescue.py @@ -0,0 +1,114 @@ +""" +Plot the §4 penalty-rescue figure: per-epoch trajectories of the protocol's +(a) and (b) diagnostics on vanilla DFA vs penalized DFA, plus accuracy. + +Shows visually that the penalty rescues both diagnostics back into the +healthy regime, while the headline accuracy gain is small (the second +failure mode persists). + +Data sources: + - vanilla DFA trajectory: results/snapshot_evolution_v2/snapshot_evolution_s42.json + - penalized DFA (lam=1e-2): results/dfa_residual_penalty/dfa_pen_lam0.01_s42.json + - DFA-shallow baseline 3-seed mean (drawn as horizontal line): 0.349 + - BP-trainable 3-seed mean: 0.609 + +Run: + python -m protocol.examples.plot_penalty_rescue +""" +import os +import sys +import json + +import numpy as np +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +REPO_ROOT = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) + + +def load_vanilla_log(): + with open(os.path.join(REPO_ROOT, "results/snapshot_evolution_v2/snapshot_evolution_s42.json")) as f: + d = json.load(f) + log = d["dfa_log"] + return [ + { + "epoch": e["epoch"], + "h_L": e["hidden_norms"][-1], + "g_2": e["bp_grad_norms_per_sample_med"][2], + "acc": e["acc_eval"], + } + for e in log + ] + + +def load_penalty_log(): + with open(os.path.join(REPO_ROOT, "results/dfa_residual_penalty/dfa_pen_lam0.01_s42.json")) as f: + d = json.load(f) + return d["log"], d["final_test_acc"] + + +def main(): + vanilla = load_vanilla_log() + penalty, penalty_final_acc = load_penalty_log() + + fig, axes = plt.subplots(1, 3, figsize=(13, 4.0), gridspec_kw={"wspace": 0.35}) + + # --- (1) ‖h_L‖ --- + ax = axes[0] + ax.plot([e["epoch"] for e in vanilla], [e["h_L"] for e in vanilla], + label="vanilla DFA", color="C3", lw=2) + ax.plot([e["epoch"] for e in penalty], [e["h_L_norm"] for e in penalty], + label=r"DFA + $\lambda \|f_l\|^2$ ($\lambda=10^{-2}$)", color="C2", lw=2, marker="o", markersize=4) + ax.axhline(200, color="C0", linestyle=":", lw=1, label=r"BP $\|h_L\| \approx 200$") + ax.set_yscale("log") + ax.set_xlabel("epoch", fontsize=10) + ax.set_ylabel(r"$\|h_L\|_2$ (log)", fontsize=10) + ax.set_title("(a) residual stream", fontsize=11) + ax.legend(loc="lower right", fontsize=7.5) + ax.grid(True, which="both", alpha=0.3) + + # --- (2) ‖g_2‖ --- + ax = axes[1] + ax.plot([e["epoch"] for e in vanilla], [e["g_2"] for e in vanilla], + label="vanilla DFA", color="C3", lw=2) + ax.plot([e["epoch"] for e in penalty], [e["g_2_norm"] for e in penalty], + label=r"DFA + $\lambda \|f_l\|^2$", color="C2", lw=2, marker="o", markersize=4) + ax.axhline(1e-7, color="k", linestyle="--", lw=1.2, label=r"floor $10^{-7}$") + ax.set_yscale("log") + ax.set_xlabel("epoch", fontsize=10) + ax.set_ylabel(r"$\|g_L\|_2$ (log)", fontsize=10) + ax.set_title("(b) BP grad at hidden layer", fontsize=11) + ax.legend(loc="lower left", fontsize=7.5) + ax.grid(True, which="both", alpha=0.3) + + # --- (3) accuracy --- + ax = axes[2] + ax.plot([e["epoch"] for e in vanilla], [e["acc"] for e in vanilla], + label="vanilla DFA", color="C3", lw=2) + ax.plot([e["epoch"] for e in penalty], [e["acc_eval"] for e in penalty], + label=r"DFA + $\lambda \|f_l\|^2$", color="C2", lw=2, marker="o", markersize=4) + ax.axhline(0.349, color="k", linestyle="--", lw=1.2, label="DFA-shallow 0.349") + ax.axhline(0.609, color="C0", linestyle=":", lw=1, label="BP-trainable 0.609") + ax.set_xlabel("epoch", fontsize=10) + ax.set_ylabel("test acc", fontsize=10) + ax.set_title("(d) headline accuracy", fontsize=11) + ax.legend(loc="lower right", fontsize=7.5) + ax.grid(True, alpha=0.3) + ax.set_ylim(0.05, 0.7) + + fig.suptitle( + r"Penalty rescue (4-block d=256 ResMLP, seed 42): the $\|f_l\|^2$ penalty fixes (a) and (b)," + "\nbut the deep blocks still fail to clear the frozen baseline (the second failure mode)", + fontsize=11, y=1.05 + ) + fig.tight_layout() + out_path = os.path.join(REPO_ROOT, "results/protocol_audit/figure_penalty_rescue_s42.png") + fig.savefig(out_path, dpi=140, bbox_inches="tight") + print(f"Saved {out_path}") + + +if __name__ == "__main__": + main() |
