diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 23:29:13 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 23:29:13 -0500 |
| commit | 9f7424553392e2f4b9f6e90a71b3b6e1e52f303f (patch) | |
| tree | f3bba04c1490d1a722a6a42507a68b4c5a5217d6 /protocol | |
| parent | 665c9bb4ab3a5126c6fc191eecf42be7b703eb0c (diff) | |
Add §2/§3 hero figure: 5-method audit horizontal bar chart
4-panel layout (one per diagnostic), 5 methods sorted top-to-bottom by
ascending accuracy, color-coded healthy (BP/EP, blue) vs degenerate
(DFA/SB/CB, red), with threshold lines drawn:
(a) max per-block growth (log scale, threshold 50x)
(b) ||g_L|| (log scale, floor 1e-7)
(c) cross-batch stability (linear, ceiling 0.30)
(d) headline acc (linear, frozen baseline 0.349)
The visual layout makes it immediately obvious that:
- (a) and (b) cleanly split healthy from degenerate (4-7 OOM gap)
- (c) is bimodal and doesn't cleanly split — confirms it's a sub-mode
discriminator, not a primary detector
- (d) shows BP above the frozen baseline by ~25 pp while DFA/CB/SB
are at or below it
Diffstat (limited to 'protocol')
| -rw-r--r-- | protocol/examples/plot_audit_table.py | 135 |
1 files changed, 135 insertions, 0 deletions
"""
Plot the §2 5-method audit table as a paper-ready figure.

Layout: 4 panels (one per diagnostic), each is a horizontal bar chart of
3-seed mean values per method, with the threshold line drawn. Methods
sorted top-to-bottom by ascending acc.

Run:
    python -m protocol.examples.plot_audit_table
"""
import json
import math
import os
import sys

import numpy as np

REPO_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sys.path.insert(0, REPO_ROOT)

AUDIT_PATH = os.path.join(REPO_ROOT, "results/protocol_audit/audit_table_s42_s123_s456.json")

# Plot order: the first entry is drawn at y=0, i.e. as the bottom bar.
_METHODS = ["bp", "ep", "dfa", "credit_bridge", "state_bridge"]
_METHOD_LABELS = {
    "bp": "BP",
    "ep": "EP",
    "dfa": "DFA",
    "credit_bridge": "Credit\nBridge",
    "state_bridge": "State\nBridge",
}
# Methods drawn in the "healthy" colour; every other method is "degenerate".
_HEALTHY = {"bp", "ep"}
_HEALTHY_COLOR = "#4682b4"
_DEGENERATE_COLOR = "#cc4444"


def max_per_block_growth(h):
    """Return the largest ratio between consecutive entries of ``h``.

    Each ratio is ``h[i + 1] / max(h[i], 1e-30)``; the clamp keeps a zero
    (or tiny) denominator from raising or overflowing to infinity.

    Args:
        h: Sequence of numbers (the caller passes a report's
            ``"residual_norms"`` list).

    Returns:
        The maximum consecutive ratio, or ``1.0`` when ``h`` has fewer
        than two entries and therefore no ratio exists.
    """
    if len(h) < 2:
        return 1.0
    return max(nxt / max(prev, 1e-30) for prev, nxt in zip(h, h[1:]))


def _collect_per_seed_values(data, methods):
    """Group the per-seed diagnostics in ``data`` by method.

    ``data["summary"]`` is iterated as rows carrying ``"method"``,
    ``"seed"``, ``"acc"`` and ``"stability"``; the matching entry in
    ``data["reports"]`` is looked up under ``"<method>_s<seed>"`` and must
    provide ``"residual_norms"`` and ``"bp_grad_norms"``.

    Returns:
        ``{method: {"acc": [...], "max_per_block": [...], "g_L": [...],
        "stab": [...]}}`` with one list element per summary row. Rows
        whose method is not in ``methods`` are ignored.
    """
    agg = {m: {"acc": [], "max_per_block": [], "g_L": [], "stab": []} for m in methods}
    for row in data["summary"]:
        method = row["method"]
        if method not in agg:
            continue
        report = data["reports"][f"{method}_s{row['seed']}"]
        bucket = agg[method]
        bucket["acc"].append(row["acc"])
        bucket["max_per_block"].append(max_per_block_growth(report["residual_norms"]))
        # Last entry of the BP gradient-norm list; plotted as ‖g_L‖.
        bucket["g_L"].append(report["bp_grad_norms"][-1])
        bucket["stab"].append(row["stability"])
    return agg


def _draw_panel(ax, positions, samples, colors, *, threshold, threshold_label,
                xlabel, title, legend_loc, log_x=False, xlim=None,
                tick_labels=None):
    """Draw one horizontal mean ± std bar panel with its threshold line.

    Args:
        ax: Matplotlib axes to draw on.
        positions: Y positions of the bars, one per method.
        samples: One list of per-seed values per method, in bar order.
        colors: One bar colour per method.
        threshold: X position of the dashed vertical reference line.
        threshold_label: Legend label for that line.
        xlabel: X-axis label.
        title: Panel title.
        legend_loc: Legend location string.
        log_x: Use a log x-axis and draw minor grid lines as well.
        xlim: Optional ``(left, right)`` x-limits.
        tick_labels: Y tick labels; ``None`` hides them.
    """
    means = [np.mean(values) for values in samples]
    stds = [np.std(values) for values in samples]
    ax.barh(positions, means, xerr=stds, color=colors, alpha=0.85, capsize=3)
    if log_x:
        ax.set_xscale("log")
    ax.axvline(threshold, color="k", linestyle="--", lw=1.2, label=threshold_label)
    ax.set_yticks(positions)
    if tick_labels is None:
        # An empty label list keeps the ticks but hides their text.
        ax.set_yticklabels([])
    else:
        ax.set_yticklabels(tick_labels, fontsize=10)
    ax.set_xlabel(xlabel, fontsize=10)
    ax.set_title(title, fontsize=11)
    ax.legend(loc=legend_loc, fontsize=8)
    ax.grid(True, axis="x", which="both" if log_x else "major", alpha=0.3)
    if xlim is not None:
        ax.set_xlim(*xlim)


def main():
    """Read the audit JSON at ``AUDIT_PATH`` and save the 4-panel figure.

    The PNG is written to ``results/protocol_audit/figure_audit_5method.png``
    under ``REPO_ROOT`` and the output path is printed.
    """
    # Imported here rather than at module level so that importing this module
    # (e.g. to reuse max_per_block_growth) neither needs a plotting stack nor
    # switches the process-wide matplotlib backend as a side effect.
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    with open(AUDIT_PATH, encoding="utf-8") as f:
        data = json.load(f)

    agg = _collect_per_seed_values(data, _METHODS)

    def samples(key):
        return [agg[m][key] for m in _METHODS]

    colors = [_HEALTHY_COLOR if m in _HEALTHY else _DEGENERATE_COLOR for m in _METHODS]
    y = np.arange(len(_METHODS))

    fig, axes = plt.subplots(1, 4, figsize=(16, 4.5), gridspec_kw={"wspace": 0.5})
    try:
        # --- (a) max per-block growth --- #
        _draw_panel(
            axes[0], y, samples("max_per_block"), colors,
            threshold=50, threshold_label="threshold 50×",
            xlabel="max per-block growth (log)",
            title="(a) ‖h_l‖ explosion",
            legend_loc="lower right", log_x=True,
            # Only the left-most panel carries the method names.
            tick_labels=[_METHOD_LABELS[m] for m in _METHODS],
        )

        # --- (b) g_L floor --- #
        _draw_panel(
            axes[1], y, samples("g_L"), colors,
            threshold=1e-7, threshold_label="floor 1e-7",
            xlabel="‖g_L‖ (log)",
            title="(b) BP grad at floor",
            legend_loc="lower right", log_x=True,
        )

        # --- (c) stability --- #
        _draw_panel(
            axes[2], y, samples("stab"), colors,
            threshold=0.30, threshold_label="ceiling 0.30",
            xlabel="cross-batch stability",
            title="(c) reference drift",
            legend_loc="upper right", xlim=(-0.2, 1.05),
        )

        # --- (d) accuracy + frozen baseline --- #
        _draw_panel(
            axes[3], y, samples("acc"), colors,
            threshold=0.349, threshold_label="frozen baseline 0.349",
            xlabel="test accuracy",
            title="(d) headline acc vs frozen",
            legend_loc="lower right", xlim=(0, 0.7),
        )

        fig.suptitle(
            "5-method audit on 4-block d=256 ResMLP CIFAR-10 (3-seed mean ± std)",
            fontsize=13, y=1.02
        )
        fig.tight_layout()
        out_path = os.path.join(REPO_ROOT, "results/protocol_audit/figure_audit_5method.png")
        fig.savefig(out_path, dpi=140, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print(f"Saved {out_path}")


if __name__ == "__main__":
    main()
