diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-08 00:03:49 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-08 00:03:49 -0500 |
| commit | 76edf529be1b8aa8813ce380d104eaa424a3dc1d (patch) | |
| tree | b38df942983d246379c3121d7df6f4266c8c343e /protocol | |
| parent | 9d0e4901a82763ea3ebc57eea152a730330d4991 (diff) | |
Add penalty λ sweep figure: shows λ-dependence of (d) verdict
3-panel figure: vanilla DFA + penalty at λ=1e-3 (green) + penalty at
λ=1e-2 (blue):
(a) ‖h_L‖: vanilla 4e8, both penalties ~4e4 (similar)
(b) ‖g_2‖: vanilla 5e-10, penalties 7e-7 to 1e-6 (above floor)
(c) acc: vanilla 0.31, λ=1e-2 0.36, λ=1e-3 0.37; horizontal lines
at DFA-shallow 0.349 and 2pp threshold 0.371
Visual: at λ=1e-3 the test acc curve crosses ABOVE the 2pp threshold
line; at λ=1e-2 it stays below. This is the (d) lambda-dependence
finding from the round 18 follow-up.
Diffstat (limited to 'protocol')
| -rw-r--r-- | protocol/examples/temporal_penalty_vs_vanilla.py | 114 |
1 files changed, 114 insertions, 0 deletions
# File: protocol/examples/temporal_penalty_vs_vanilla.py (new file, mode 100644)
"""
Temporal trajectory of the protocol's (a) and (b) diagnostics under
DFA + residual-branch penalty, vs vanilla DFA.

Uses the existing per-epoch logs:
  - vanilla DFA:   results/snapshot_evolution_v2/snapshot_evolution_s42.json
  - penalized DFA: results/dfa_residual_penalty/dfa_pen_lam{0.001,0.01,0.1}_s42.json
    (only lam=0.001 and lam=0.01 are plotted; the lam=0.1 run was killed
    mid-run, so its log is not loaded)

Writes the 3-panel figure to
results/protocol_audit/figure_penalty_lambda_sweep.png.

Run:
    python -m protocol.examples.temporal_penalty_vs_vanilla
"""
import os
import sys
import json

import matplotlib
matplotlib.use("Agg")  # select the Agg backend before pyplot is imported
import matplotlib.pyplot as plt

# Three directory levels above this file (protocol/examples/<this file>).
REPO_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)

# Metric name -> key under which that metric is stored in each kind of log.
_VANILLA_KEYS = {"h_L": "h_L", "g_2": "g_2", "acc": "acc"}
_PENALTY_KEYS = {"h_L": "h_L_norm", "g_2": "g_2_norm", "acc": "acc_eval"}


def load_vanilla_log():
    """Load the vanilla-DFA per-epoch log and flatten it for plotting.

    Reads ``d["dfa_log"]`` from the snapshot-evolution JSON and returns a
    list of dicts with keys ``epoch``, ``h_L`` (last entry of
    ``hidden_norms``), ``g_2`` (index 2 of
    ``bp_grad_norms_per_sample_med``) and ``acc`` (``acc_eval``).

    Raises:
        FileNotFoundError: if the snapshot JSON does not exist.
    """
    path = os.path.join(
        REPO_ROOT, "results/snapshot_evolution_v2/snapshot_evolution_s42.json"
    )
    with open(path, encoding="utf-8") as f:
        d = json.load(f)
    return [
        {
            "epoch": e["epoch"],
            "h_L": e["hidden_norms"][-1],
            "g_2": e["bp_grad_norms_per_sample_med"][2],
            "acc": e["acc_eval"],
        }
        for e in d["dfa_log"]
    ]


def load_penalty_log(lam):
    """Load the per-epoch log of a penalized-DFA run.

    Args:
        lam: the penalty weight exactly as it appears in the file name
            (e.g. ``"0.001"``).

    Returns:
        The ``"log"`` entry of the run's JSON, or ``None`` when the file
        does not exist (best-effort: the caller skips that curve).
    """
    path = os.path.join(
        REPO_ROOT, f"results/dfa_residual_penalty/dfa_pen_lam{lam}_s42.json"
    )
    if not os.path.exists(path):
        return None
    with open(path, encoding="utf-8") as f:
        d = json.load(f)
    return d["log"]


def _plot_metric(ax, runs, metric):
    """Draw one curve per run on *ax* for *metric* ('h_L', 'g_2' or 'acc')."""
    for run in runs:
        log = run["log"]
        ax.plot(
            [e["epoch"] for e in log],
            [e[run["keys"][metric]] for e in log],
            label=run["label"],
            color=run["color"],
            lw=2,
            marker=run["marker"],
            markersize=4,
            linestyle="-",
        )


def main():
    """Build the 3-panel penalty-lambda-sweep figure and save it as a PNG."""
    runs = [
        {"label": "vanilla DFA", "log": load_vanilla_log(),
         "keys": _VANILLA_KEYS, "color": "C3", "marker": None},
    ]
    # NOTE: the lam=0.1 run was killed mid-run, so it is not listed here.
    penalty_specs = [
        ("DFA + λ=1e-3", "0.001", "C2", "o"),
        ("DFA + λ=1e-2", "0.01", "C0", "s"),
    ]
    for label, lam, color, marker in penalty_specs:
        log = load_penalty_log(lam)
        if not log:
            # Keep the best-effort behaviour (skip the curve) but say so,
            # otherwise the figure silently loses part of the sweep.
            print(f"Skipping {label}: no log found for lam={lam}",
                  file=sys.stderr)
            continue
        runs.append({"label": label, "log": log, "keys": _PENALTY_KEYS,
                     "color": color, "marker": marker})

    fig, axes = plt.subplots(1, 3, figsize=(13, 4.0), gridspec_kw={"wspace": 0.35})

    # --- (a) ‖h_L‖ ---
    ax = axes[0]
    _plot_metric(ax, runs, "h_L")
    ax.axhline(200, color="gray", linestyle=":", lw=1, label="BP ~200")
    ax.set_yscale("log")
    ax.set_xlabel("epoch", fontsize=10)
    ax.set_ylabel(r"$\|h_L\|_2$ (log)", fontsize=10)
    ax.set_title("(a) residual stream", fontsize=11)
    ax.legend(loc="lower right", fontsize=8)
    ax.grid(True, which="both", alpha=0.3)

    # --- (b) ‖g_2‖ ---
    ax = axes[1]
    _plot_metric(ax, runs, "g_2")
    ax.axhline(1e-7, color="k", linestyle="--", lw=1.2, label="floor 1e-7")
    ax.set_yscale("log")
    ax.set_xlabel("epoch", fontsize=10)
    ax.set_ylabel(r"$\|g_2\|_2$ (log)", fontsize=10)
    ax.set_title("(b) BP grad floor", fontsize=11)
    ax.legend(loc="lower left", fontsize=8)
    ax.grid(True, which="both", alpha=0.3)

    # --- (c) accuracy ---
    ax = axes[2]
    _plot_metric(ax, runs, "acc")
    ax.axhline(0.349, color="k", linestyle="--", lw=1.0, label="DFA-shallow 0.349")
    ax.axhline(0.371, color="purple", linestyle=":", lw=1.0, label="2pp threshold 0.371")
    ax.set_xlabel("epoch", fontsize=10)
    ax.set_ylabel("test acc", fontsize=10)
    ax.set_title("acc + frozen baseline + (d) threshold", fontsize=11)
    ax.legend(loc="lower right", fontsize=7.5)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0.05, 0.55)

    fig.suptitle(
        "Penalty rescue: (a) and (b) cleanly fix; (d) verdict depends on λ choice",
        fontsize=11, y=1.05
    )
    fig.tight_layout()

    out_path = os.path.join(REPO_ROOT, "results/protocol_audit/figure_penalty_lambda_sweep.png")
    # savefig does not create missing directories; make sure the target exists.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, dpi=140, bbox_inches="tight")
    plt.close(fig)  # release the figure once it is on disk
    print(f"Saved {out_path}")


if __name__ == "__main__":
    main()
