summaryrefslogtreecommitdiff
path: root/protocol
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 00:03:49 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 00:03:49 -0500
commit76edf529be1b8aa8813ce380d104eaa424a3dc1d (patch)
treeb38df942983d246379c3121d7df6f4266c8c343e /protocol
parent9d0e4901a82763ea3ebc57eea152a730330d4991 (diff)
Add penalty λ sweep figure: shows λ-dependence of (d) verdict
3-panel figure: vanilla DFA + penalty at λ=1e-3 (green) + penalty at λ=1e-2 (blue): (a) ‖h_L‖: vanilla 4e8, both penalties ~4e4 (similar) (b) ‖g_2‖: vanilla 5e-10, penalties 7e-7 to 1e-6 (above floor) (c) acc: vanilla 0.31, λ=1e-2 0.36, λ=1e-3 0.37; horizontal lines at DFA-shallow 0.349 and 2pp threshold 0.371 Visual: at λ=1e-3 the test acc curve crosses ABOVE the 2pp threshold line; at λ=1e-2 it stays below. This is the (d) lambda-dependence finding from the round 18 follow-up.
Diffstat (limited to 'protocol')
-rw-r--r--protocol/examples/temporal_penalty_vs_vanilla.py114
1 files changed, 114 insertions, 0 deletions
diff --git a/protocol/examples/temporal_penalty_vs_vanilla.py b/protocol/examples/temporal_penalty_vs_vanilla.py
new file mode 100644
index 0000000..4c37418
--- /dev/null
+++ b/protocol/examples/temporal_penalty_vs_vanilla.py
@@ -0,0 +1,114 @@
+"""
+Temporal trajectory of the protocol's (a) and (b) diagnostics under
+DFA + residual-branch penalty, vs vanilla DFA.
+
+Uses the existing per-epoch logs:
+ - vanilla DFA: results/snapshot_evolution_v2/snapshot_evolution_s42.json
+ - penalized DFA (3 lam values): results/dfa_residual_penalty/dfa_pen_lam{0.001,0.01,0.1}_s42.json
+
+Run:
+ python -m protocol.examples.temporal_penalty_vs_vanilla
+"""
+import os
+import sys
+import json
+
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+
+REPO_ROOT = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+
+
+def load_vanilla_log():
+ with open(os.path.join(REPO_ROOT, "results/snapshot_evolution_v2/snapshot_evolution_s42.json")) as f:
+ d = json.load(f)
+ log = d["dfa_log"]
+ return [
+ {"epoch": e["epoch"],
+ "h_L": e["hidden_norms"][-1],
+ "g_2": e["bp_grad_norms_per_sample_med"][2],
+ "acc": e["acc_eval"]}
+ for e in log
+ ]
+
+
+def load_penalty_log(lam):
+ path = os.path.join(REPO_ROOT, f"results/dfa_residual_penalty/dfa_pen_lam{lam}_s42.json")
+ if not os.path.exists(path):
+ return None
+ with open(path) as f:
+ d = json.load(f)
+ return d["log"]
+
+
+def main():
+ vanilla = load_vanilla_log()
+ p_1e3 = load_penalty_log("0.001")
+ p_1e2 = load_penalty_log("0.01")
+ # p_1e1 = load_penalty_log("0.1") # was killed mid-run
+
+ fig, axes = plt.subplots(1, 3, figsize=(13, 4.0), gridspec_kw={"wspace": 0.35})
+
+ plots = [
+ ("vanilla DFA", vanilla, "h_L", "g_2", "acc", "C3", None, "-", 2),
+ ]
+ if p_1e3:
+ plots.append(("DFA + λ=1e-3", p_1e3, "h_L_norm", "g_2_norm", "acc_eval", "C2", "o", "-", 2))
+ if p_1e2:
+ plots.append(("DFA + λ=1e-2", p_1e2, "h_L_norm", "g_2_norm", "acc_eval", "C0", "s", "-", 2))
+
+ # --- (a) ‖h_L‖ ---
+ ax = axes[0]
+ for label, log, hk, gk, ak, c, m, ls, lw in plots:
+ ax.plot([e["epoch"] for e in log], [e[hk] for e in log],
+ label=label, color=c, lw=lw, marker=m, markersize=4, linestyle=ls)
+ ax.axhline(200, color="gray", linestyle=":", lw=1, label="BP ~200")
+ ax.set_yscale("log")
+ ax.set_xlabel("epoch", fontsize=10)
+ ax.set_ylabel(r"$\|h_L\|_2$ (log)", fontsize=10)
+ ax.set_title("(a) residual stream", fontsize=11)
+ ax.legend(loc="lower right", fontsize=8)
+ ax.grid(True, which="both", alpha=0.3)
+
+ # --- (b) ‖g_2‖ ---
+ ax = axes[1]
+ for label, log, hk, gk, ak, c, m, ls, lw in plots:
+ ax.plot([e["epoch"] for e in log], [e[gk] for e in log],
+ label=label, color=c, lw=lw, marker=m, markersize=4, linestyle=ls)
+ ax.axhline(1e-7, color="k", linestyle="--", lw=1.2, label="floor 1e-7")
+ ax.set_yscale("log")
+ ax.set_xlabel("epoch", fontsize=10)
+ ax.set_ylabel(r"$\|g_2\|_2$ (log)", fontsize=10)
+ ax.set_title("(b) BP grad floor", fontsize=11)
+ ax.legend(loc="lower left", fontsize=8)
+ ax.grid(True, which="both", alpha=0.3)
+
+ # --- (c) accuracy ---
+ ax = axes[2]
+ for label, log, hk, gk, ak, c, m, ls, lw in plots:
+ ax.plot([e["epoch"] for e in log], [e[ak] for e in log],
+ label=label, color=c, lw=lw, marker=m, markersize=4, linestyle=ls)
+ ax.axhline(0.349, color="k", linestyle="--", lw=1.0, label="DFA-shallow 0.349")
+ ax.axhline(0.371, color="purple", linestyle=":", lw=1.0, label="2pp threshold 0.371")
+ ax.set_xlabel("epoch", fontsize=10)
+ ax.set_ylabel("test acc", fontsize=10)
+ ax.set_title("acc + frozen baseline + (d) threshold", fontsize=11)
+ ax.legend(loc="lower right", fontsize=7.5)
+ ax.grid(True, alpha=0.3)
+ ax.set_ylim(0.05, 0.55)
+
+ fig.suptitle(
+ "Penalty rescue: (a) and (b) cleanly fix; (d) verdict depends on λ choice",
+ fontsize=11, y=1.05
+ )
+ fig.tight_layout()
+ out_path = os.path.join(REPO_ROOT, "results/protocol_audit/figure_penalty_lambda_sweep.png")
+ fig.savefig(out_path, dpi=140, bbox_inches="tight")
+ print(f"Saved {out_path}")
+
+
+if __name__ == "__main__":
+ main()