"""Plot the §2 5-method audit table as a paper-ready figure.

Layout: 4 panels in a single row (three diagnostics plus test accuracy). Each
panel is a horizontal bar chart of the per-method mean over the seeds found in
the audit file, with the standard deviation as an error bar and the panel's
threshold/reference line drawn as a dashed vertical line.

Methods are drawn in the fixed order listed in ``main`` (the first entry gets
y-position 0); they are NOT re-sorted by accuracy.

Run:
    python -m protocol.examples.plot_audit_table
"""
import os
import sys
import json
import math
import numpy as np
import matplotlib

# Select the non-interactive backend before pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Three directory levels above this file.
REPO_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sys.path.insert(0, REPO_ROOT)

AUDIT_PATH = os.path.join(
    REPO_ROOT, "results/protocol_audit/audit_table_s42_s123_s456.json"
)


def max_per_block_growth(h):
    """Return the largest ratio ``h[i + 1] / h[i]`` over consecutive entries.

    Args:
        h: Sequence of numbers (sliceable, e.g. a list). In ``main`` this is
            a report's ``"residual_norms"`` list.

    Returns:
        The maximum consecutive ratio, or ``1.0`` when ``h`` has fewer than
        two entries. Each denominator is floored at ``1e-30`` so a zero entry
        yields a very large ratio instead of a ``ZeroDivisionError``.
    """
    if len(h) < 2:
        return 1.0
    return max(nxt / max(cur, 1e-30) for cur, nxt in zip(h, h[1:]))


def _draw_panel(ax, y, values, errors, colors, *, ref_value, ref_label,
                xlabel, title, legend_loc, tick_labels=(), log_scale=False,
                xlim=None):
    """Draw one horizontal-bar panel with its dashed reference line.

    Args:
        ax: Matplotlib axes to draw on.
        y: Bar positions along the y axis.
        values: Bar lengths (one per position).
        errors: Symmetric x error-bar sizes (one per position).
        colors: Bar colors (one per position).
        ref_value: x position of the dashed vertical reference line.
        ref_label: Legend label for the reference line.
        xlabel: x-axis label.
        title: Panel title.
        legend_loc: ``loc`` argument forwarded to ``ax.legend``.
        tick_labels: y tick labels; when empty the tick labels are blanked.
        log_scale: Use a log x axis and draw minor grid lines as well.
        xlim: Optional ``(lo, hi)`` x-axis limits, applied last.
    """
    ax.barh(y, values, xerr=errors, color=colors, alpha=0.85, capsize=3)
    if log_scale:
        ax.set_xscale("log")
    ax.axvline(ref_value, color="k", linestyle="--", lw=1.2, label=ref_label)
    ax.set_yticks(y)
    if tick_labels:
        ax.set_yticklabels(list(tick_labels), fontsize=10)
    else:
        ax.set_yticklabels([])
    ax.set_xlabel(xlabel, fontsize=10)
    ax.set_title(title, fontsize=11)
    ax.legend(loc=legend_loc, fontsize=8)
    # Log panels show major and minor grid lines; linear panels major only.
    ax.grid(True, axis="x", which="both" if log_scale else "major", alpha=0.3)
    if xlim is not None:
        ax.set_xlim(*xlim)


def main():
    """Read the audit JSON, aggregate per method, and save the 4-panel figure.

    Reads ``AUDIT_PATH``, which must contain a ``"summary"`` list of rows
    (with ``"method"``, ``"seed"``, ``"acc"``, ``"stability"``) and a
    ``"reports"`` mapping keyed ``"<method>_s<seed>"`` (with
    ``"residual_norms"`` and ``"bp_grad_norms"``). Writes
    ``results/protocol_audit/figure_audit_5method.png`` under ``REPO_ROOT``
    and prints the output path.

    Raises:
        OSError: If the audit file cannot be opened.
        KeyError: If an expected key (including a report for a summary row)
            is missing from the audit data.
    """
    # JSON is UTF-8; do not depend on the locale's default encoding.
    with open(AUDIT_PATH, encoding="utf-8") as f:
        data = json.load(f)

    methods = ["bp", "ep", "dfa", "credit_bridge", "state_bridge"]
    method_labels = {
        "bp": "BP",
        "ep": "EP",
        "dfa": "DFA",
        "credit_bridge": "Credit\nBridge",
        "state_bridge": "State\nBridge",
    }

    # Aggregate per-method values across seeds.
    agg = {
        m: {"acc": [], "max_per_block": [], "g_L": [], "stab": []}
        for m in methods
    }
    for row in data["summary"]:
        m = row["method"]
        if m not in agg:
            # Rows for methods outside the plotted set are ignored.
            continue
        report = data["reports"][f"{m}_s{row['seed']}"]
        agg[m]["acc"].append(row["acc"])
        agg[m]["max_per_block"].append(
            max_per_block_growth(report["residual_norms"])
        )
        # Last entry of the report's "bp_grad_norms" list.
        agg[m]["g_L"].append(report["bp_grad_norms"][-1])
        agg[m]["stab"].append(row["stability"])

    # Healthy/degenerate color coding
    healthy = {"bp", "ep"}
    colors = ["#4682b4" if m in healthy else "#cc4444" for m in methods]

    fig, axes = plt.subplots(
        1, 4, figsize=(16, 4.5), gridspec_kw={"wspace": 0.5}
    )
    y = np.arange(len(methods))

    panels = [
        # --- (a) max per-block growth --- #
        {
            "key": "max_per_block",
            "ref_value": 50,
            "ref_label": "threshold 50×",
            "xlabel": "max per-block growth (log)",
            "title": "(a) ‖h_l‖ explosion",
            "legend_loc": "lower right",
            "log_scale": True,
        },
        # --- (b) g_L floor --- #
        {
            "key": "g_L",
            "ref_value": 1e-7,
            "ref_label": "floor 1e-7",
            "xlabel": "‖g_L‖ (log)",
            "title": "(b) BP grad at floor",
            "legend_loc": "lower right",
            "log_scale": True,
        },
        # --- (c) stability --- #
        {
            "key": "stab",
            "ref_value": 0.30,
            "ref_label": "ceiling 0.30",
            "xlabel": "cross-batch stability",
            "title": "(c) reference drift",
            "legend_loc": "upper right",
            "xlim": (-0.2, 1.05),
        },
        # --- (d) accuracy + frozen baseline --- #
        {
            "key": "acc",
            "ref_value": 0.349,
            "ref_label": "frozen baseline 0.349",
            "xlabel": "test accuracy",
            "title": "(d) headline acc vs frozen",
            "legend_loc": "lower right",
            "xlim": (0, 0.7),
        },
    ]

    for idx, (ax, spec) in enumerate(zip(axes, panels)):
        key = spec["key"]
        means = [np.mean(agg[m][key]) for m in methods]
        stds = [np.std(agg[m][key]) for m in methods]
        # Only the left-most panel carries method names on its y axis.
        tick_labels = [method_labels[m] for m in methods] if idx == 0 else []
        _draw_panel(
            ax, y, means, stds, colors,
            ref_value=spec["ref_value"],
            ref_label=spec["ref_label"],
            xlabel=spec["xlabel"],
            title=spec["title"],
            legend_loc=spec["legend_loc"],
            tick_labels=tick_labels,
            log_scale=spec.get("log_scale", False),
            xlim=spec.get("xlim"),
        )

    fig.suptitle(
        "5-method audit on 4-block d=256 ResMLP CIFAR-10 (3-seed mean ± std)",
        fontsize=13,
        y=1.02,
    )
    fig.tight_layout()

    out_path = os.path.join(
        REPO_ROOT, "results/protocol_audit/figure_audit_5method.png"
    )
    fig.savefig(out_path, dpi=140, bbox_inches="tight")
    # Release the figure now that it is on disk.
    plt.close(fig)
    print(f"Saved {out_path}")


if __name__ == "__main__":
    main()