""" Temporal trajectory of the protocol's (a) and (b) diagnostics under DFA + residual-branch penalty, vs vanilla DFA. Uses the existing per-epoch logs: - vanilla DFA: results/snapshot_evolution_v2/snapshot_evolution_s42.json - penalized DFA (3 lam values): results/dfa_residual_penalty/dfa_pen_lam{0.001,0.01,0.1}_s42.json Run: python -m protocol.examples.temporal_penalty_vs_vanilla """ import os import sys import json import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt REPO_ROOT = os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) def load_vanilla_log(): with open(os.path.join(REPO_ROOT, "results/snapshot_evolution_v2/snapshot_evolution_s42.json")) as f: d = json.load(f) log = d["dfa_log"] return [ {"epoch": e["epoch"], "h_L": e["hidden_norms"][-1], "g_2": e["bp_grad_norms_per_sample_med"][2], "acc": e["acc_eval"]} for e in log ] def load_penalty_log(lam): path = os.path.join(REPO_ROOT, f"results/dfa_residual_penalty/dfa_pen_lam{lam}_s42.json") if not os.path.exists(path): return None with open(path) as f: d = json.load(f) return d["log"] def main(): vanilla = load_vanilla_log() p_1e3 = load_penalty_log("0.001") p_1e2 = load_penalty_log("0.01") # p_1e1 = load_penalty_log("0.1") # was killed mid-run fig, axes = plt.subplots(1, 3, figsize=(13, 4.0), gridspec_kw={"wspace": 0.35}) plots = [ ("vanilla DFA", vanilla, "h_L", "g_2", "acc", "C3", None, "-", 2), ] if p_1e3: plots.append(("DFA + λ=1e-3", p_1e3, "h_L_norm", "g_2_norm", "acc_eval", "C2", "o", "-", 2)) if p_1e2: plots.append(("DFA + λ=1e-2", p_1e2, "h_L_norm", "g_2_norm", "acc_eval", "C0", "s", "-", 2)) # --- (a) ‖h_L‖ --- ax = axes[0] for label, log, hk, gk, ak, c, m, ls, lw in plots: ax.plot([e["epoch"] for e in log], [e[hk] for e in log], label=label, color=c, lw=lw, marker=m, markersize=4, linestyle=ls) ax.axhline(200, color="gray", linestyle=":", lw=1, label="BP ~200") ax.set_yscale("log") ax.set_xlabel("epoch", fontsize=10) ax.set_ylabel(r"$\|h_L\|_2$ (log)", fontsize=10) ax.set_title("(a) residual stream", fontsize=11) ax.legend(loc="lower right", fontsize=8) ax.grid(True, which="both", alpha=0.3) # --- (b) ‖g_2‖ --- ax = axes[1] for label, log, hk, gk, ak, c, m, ls, lw in plots: ax.plot([e["epoch"] for e in log], [e[gk] for e in log], label=label, color=c, lw=lw, marker=m, markersize=4, linestyle=ls) ax.axhline(1e-7, color="k", linestyle="--", lw=1.2, label="floor 1e-7") ax.set_yscale("log") ax.set_xlabel("epoch", fontsize=10) ax.set_ylabel(r"$\|g_2\|_2$ (log)", fontsize=10) ax.set_title("(b) BP grad floor", fontsize=11) ax.legend(loc="lower left", fontsize=8) ax.grid(True, which="both", alpha=0.3) # --- (c) accuracy --- ax = axes[2] for label, log, hk, gk, ak, c, m, ls, lw in plots: ax.plot([e["epoch"] for e in log], [e[ak] for e in log], label=label, color=c, lw=lw, marker=m, markersize=4, linestyle=ls) ax.axhline(0.349, color="k", linestyle="--", lw=1.0, label="DFA-shallow 0.349") ax.axhline(0.371, color="purple", linestyle=":", lw=1.0, label="2pp threshold 0.371") ax.set_xlabel("epoch", fontsize=10) ax.set_ylabel("test acc", fontsize=10) ax.set_title("acc + frozen baseline + (d) threshold", fontsize=11) ax.legend(loc="lower right", fontsize=7.5) ax.grid(True, alpha=0.3) ax.set_ylim(0.05, 0.55) fig.suptitle( "Penalty rescue: (a) and (b) cleanly fix; (d) verdict depends on λ choice", fontsize=11, y=1.05 ) fig.tight_layout() out_path = os.path.join(REPO_ROOT, "results/protocol_audit/figure_penalty_lambda_sweep.png") fig.savefig(out_path, dpi=140, bbox_inches="tight") print(f"Saved {out_path}") if __name__ == "__main__": main()