summaryrefslogtreecommitdiff
path: root/protocol
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:29:13 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:29:13 -0500
commit9f7424553392e2f4b9f6e90a71b3b6e1e52f303f (patch)
treef3bba04c1490d1a722a6a42507a68b4c5a5217d6 /protocol
parent665c9bb4ab3a5126c6fc191eecf42be7b703eb0c (diff)
Add §2/§3 hero figure: 5-method audit horizontal bar chart
4-panel layout (one per diagnostic), 5 methods sorted bottom-to-top by ascending accuracy, color-coded healthy (BP/EP, blue) vs degenerate (DFA/SB/CB, red), with threshold lines drawn: (a) max per-block growth (log scale, threshold 50x) (b) ||g_L|| (log scale, floor 1e-7) (c) cross-batch stability (linear, ceiling 0.30) (d) headline acc (linear, frozen baseline 0.349) The visual layout makes it immediately obvious that: - (a) and (b) cleanly split healthy from degenerate (4-7 OOM gap) - (c) is bimodal and doesn't cleanly split — confirms it's a sub-mode discriminator, not a primary detector - (d) shows BP above the frozen baseline by ~25 pp while DFA/CB/SB are at or below it
Diffstat (limited to 'protocol')
-rw-r--r--protocol/examples/plot_audit_table.py135
1 files changed, 135 insertions, 0 deletions
diff --git a/protocol/examples/plot_audit_table.py b/protocol/examples/plot_audit_table.py
new file mode 100644
index 0000000..d931eb9
--- /dev/null
+++ b/protocol/examples/plot_audit_table.py
@@ -0,0 +1,135 @@
+"""
+Plot the §2 5-method audit table as a paper-ready figure.
+
+Layout: 4 panels (one per diagnostic), each is a horizontal bar chart of
+3-seed mean values per method, with the threshold line drawn. Methods
+sorted top-to-bottom by ascending acc.
+
+Run:
+ python -m protocol.examples.plot_audit_table
+"""
+import os
+import sys
+import json
+import math
+
+import numpy as np
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+
+REPO_ROOT = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+sys.path.insert(0, REPO_ROOT)
+
+AUDIT_PATH = os.path.join(REPO_ROOT, "results/protocol_audit/audit_table_s42_s123_s456.json")
+
+
+def max_per_block_growth(h):
+ if len(h) < 2:
+ return 1.0
+ return max(h[i + 1] / max(h[i], 1e-30) for i in range(len(h) - 1))
+
+
+def main():
+ with open(AUDIT_PATH) as f:
+ data = json.load(f)
+
+ methods = ["bp", "ep", "dfa", "credit_bridge", "state_bridge"]
+ method_labels = {
+ "bp": "BP",
+ "ep": "EP",
+ "dfa": "DFA",
+ "credit_bridge": "Credit\nBridge",
+ "state_bridge": "State\nBridge",
+ }
+
+ # Aggregate per-method 3-seed values
+ agg = {m: {"acc": [], "max_per_block": [], "g_L": [], "stab": []} for m in methods}
+ for row in data["summary"]:
+ m = row["method"]
+ if m not in agg:
+ continue
+ report = data["reports"][f"{m}_s{row['seed']}"]
+ agg[m]["acc"].append(row["acc"])
+ agg[m]["max_per_block"].append(max_per_block_growth(report["residual_norms"]))
+ agg[m]["g_L"].append(report["bp_grad_norms"][-1])
+ agg[m]["stab"].append(row["stability"])
+
+ # Healthy/degenerate color coding
+ healthy = {"bp", "ep"}
+ colors = ["#4682b4" if m in healthy else "#cc4444" for m in methods]
+
+ fig, axes = plt.subplots(1, 4, figsize=(16, 4.5), gridspec_kw={"wspace": 0.5})
+
+ y = np.arange(len(methods))
+
+ # --- (a) max per-block growth --- #
+ ax = axes[0]
+ means = [np.mean(agg[m]["max_per_block"]) for m in methods]
+ stds = [np.std(agg[m]["max_per_block"]) for m in methods]
+ ax.barh(y, means, xerr=stds, color=colors, alpha=0.85, capsize=3)
+ ax.set_xscale("log")
+ ax.axvline(50, color="k", linestyle="--", lw=1.2, label="threshold 50×")
+ ax.set_yticks(y)
+ ax.set_yticklabels([method_labels[m] for m in methods], fontsize=10)
+ ax.set_xlabel("max per-block growth (log)", fontsize=10)
+ ax.set_title("(a) ‖h_l‖ explosion", fontsize=11)
+ ax.legend(loc="lower right", fontsize=8)
+ ax.grid(True, axis="x", which="both", alpha=0.3)
+
+ # --- (b) g_L floor --- #
+ ax = axes[1]
+ means = [np.mean(agg[m]["g_L"]) for m in methods]
+ stds = [np.std(agg[m]["g_L"]) for m in methods]
+ ax.barh(y, means, xerr=stds, color=colors, alpha=0.85, capsize=3)
+ ax.set_xscale("log")
+ ax.axvline(1e-7, color="k", linestyle="--", lw=1.2, label="floor 1e-7")
+ ax.set_yticks(y)
+ ax.set_yticklabels([])
+ ax.set_xlabel("‖g_L‖ (log)", fontsize=10)
+ ax.set_title("(b) BP grad at floor", fontsize=11)
+ ax.legend(loc="lower right", fontsize=8)
+ ax.grid(True, axis="x", which="both", alpha=0.3)
+
+ # --- (c) stability --- #
+ ax = axes[2]
+ means = [np.mean(agg[m]["stab"]) for m in methods]
+ stds = [np.std(agg[m]["stab"]) for m in methods]
+ ax.barh(y, means, xerr=stds, color=colors, alpha=0.85, capsize=3)
+ ax.axvline(0.30, color="k", linestyle="--", lw=1.2, label="ceiling 0.30")
+ ax.set_yticks(y)
+ ax.set_yticklabels([])
+ ax.set_xlabel("cross-batch stability", fontsize=10)
+ ax.set_title("(c) reference drift", fontsize=11)
+ ax.legend(loc="upper right", fontsize=8)
+ ax.grid(True, axis="x", alpha=0.3)
+ ax.set_xlim(-0.2, 1.05)
+
+ # --- accuracy + frozen baseline --- #
+ ax = axes[3]
+ accs = [np.mean(agg[m]["acc"]) for m in methods]
+ stds = [np.std(agg[m]["acc"]) for m in methods]
+ ax.barh(y, accs, xerr=stds, color=colors, alpha=0.85, capsize=3)
+ ax.axvline(0.349, color="k", linestyle="--", lw=1.2, label="frozen baseline 0.349")
+ ax.set_yticks(y)
+ ax.set_yticklabels([])
+ ax.set_xlabel("test accuracy", fontsize=10)
+ ax.set_title("(d) headline acc vs frozen", fontsize=11)
+ ax.legend(loc="lower right", fontsize=8)
+ ax.grid(True, axis="x", alpha=0.3)
+ ax.set_xlim(0, 0.7)
+
+ fig.suptitle(
+ "5-method audit on 4-block d=256 ResMLP CIFAR-10 (3-seed mean ± std)",
+ fontsize=13, y=1.02
+ )
+ fig.tight_layout()
+ out_path = os.path.join(REPO_ROOT, "results/protocol_audit/figure_audit_5method.png")
+ fig.savefig(out_path, dpi=140, bbox_inches="tight")
+ print(f"Saved {out_path}")
+
+
+if __name__ == "__main__":
+ main()