"""Render Figure 2: decision-utility ablation as a heatmap.

Reads the per-(method, strategy) verdict table from
``results/protocol_audit/ablation_decision_utility.json`` and writes
``paper/figures/fig2_decision_utility.pdf`` under the repository root.
Run as a script; importing the module only defines constants and helpers.
"""

import json
import os

import numpy as np

# Repository root. Defaults to the original location; set FA_REPO_ROOT to
# render from a checkout that lives elsewhere.
REPO_ROOT = os.environ.get("FA_REPO_ROOT", "/home/yurenh2/fa")
INPUT_PATH = os.path.join(
    REPO_ROOT, "results/protocol_audit/ablation_decision_utility.json"
)
OUTPUT_PATH = os.path.join(REPO_ROOT, "paper/figures/fig2_decision_utility.pdf")

methods = ["bp", "ep", "dfa", "credit_bridge", "state_bridge"]
method_labels = ["BP", "EP", "DFA", "Credit Bridge", "State Bridge"]
strategies = ["S0", "S1", "S2", "S3", "S4", "S5", "S_full"]
strategy_labels = [
    "S0\nacc only",
    "S1\n+$\\Gamma$\n(field std)",
    "S2\n+(a)\nscale",
    "S3\n+(b)\nfloor",
    "S4\n+(c)\nstability",
    "S5\n+(d)\nfrozen",
    "S$_\\mathrm{full}$\nfull",
]

# Strategy keys summarised in the title.
FIELD_STANDARD_PAIR = ("S0", "S1")
FULL_PROTOCOL = "S_full"

# Cell colours; shared by the colormap and the legend so they cannot drift.
TRUSTWORTHY_COLOR = "#4682b4"  # blue, matrix value 0
WALKBACK_COLOR = "#cc4444"  # red, matrix value 1


def verdict_matrix(table, row_keys, col_keys):
    """Return a (len(row_keys), len(col_keys)) float matrix of verdicts.

    A cell is 1.0 when ``table[row][col]`` contains the substring "WALK"
    (a walk-back verdict) and 0.0 otherwise (trustworthy).

    Raises:
        KeyError: if a requested (row, col) entry is missing from *table*.
    """
    mat = np.zeros((len(row_keys), len(col_keys)))
    for i, row in enumerate(row_keys):
        for j, col in enumerate(col_keys):
            if "WALK" in table[row][col]:
                mat[i, j] = 1.0
    return mat


def build_title(mat, col_keys):
    """Compose the figure title with walk-back counts derived from *mat*.

    The field-standard pair counts a method as walked back when either
    strategy of the pair walks it back.
    NOTE(review): confirm that union reading matches the paper's wording.
    """
    n_methods, n_strategies = mat.shape
    pair_cols = [col_keys.index(s) for s in FIELD_STANDARD_PAIR]
    pair_wb = int(mat[:, pair_cols].any(axis=1).sum())
    full_wb = int(mat[:, col_keys.index(FULL_PROTOCOL)].sum())
    return (
        f"Decision-utility ablation: {n_strategies} reporting strategies "
        f"vs {n_methods} methods\n"
        f"field-standard pair (S0, S1) walks back {pair_wb}/{n_methods}; "
        f"full protocol walks back {full_wb}/{n_methods}"
    )


def main():
    """Load the ablation table, draw the verdict heatmap and save it as PDF."""
    # Imported here so the non-interactive backend switch (a global setting,
    # required before pyplot is imported) only happens when rendering.
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap
    from matplotlib.patches import Patch

    with open(INPUT_PATH, encoding="utf-8") as f:
        data = json.load(f)

    mat = verdict_matrix(data["table"], methods, strategies)

    fig, ax = plt.subplots(figsize=(7, 3.2))
    cmap = ListedColormap([TRUSTWORTHY_COLOR, WALKBACK_COLOR])
    ax.imshow(mat, cmap=cmap, aspect="auto", vmin=0, vmax=1)

    # Annotate each cell; white text is legible on both the blue and red fills.
    for (i, j), value in np.ndenumerate(mat):
        ax.text(
            j,
            i,
            "WB" if value > 0.5 else "$\\checkmark$",
            ha="center",
            va="center",
            color="white",
            fontsize=10,
            fontweight="bold",
        )

    ax.set_xticks(range(len(strategies)))
    ax.set_xticklabels(strategy_labels, fontsize=8)
    ax.set_yticks(range(len(methods)))
    ax.set_yticklabels(method_labels, fontsize=10)
    ax.set_xlabel("reporting strategy", fontsize=10)

    # Colour legend, placed below the axes.
    legend_elements = [
        Patch(facecolor=TRUSTWORTHY_COLOR, label="$\\checkmark$ trustworthy"),
        Patch(facecolor=WALKBACK_COLOR, label="WB walked back"),
    ]
    ax.legend(
        handles=legend_elements,
        loc="upper center",
        bbox_to_anchor=(0.5, -0.25),
        ncol=2,
        fontsize=9,
        frameon=False,
    )
    ax.set_title(build_title(mat, strategies), fontsize=10)
    fig.tight_layout()

    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
    fig.savefig(OUTPUT_PATH, bbox_inches="tight", dpi=200)
    plt.close(fig)
    print(f"Saved {OUTPUT_PATH}")


if __name__ == "__main__":
    main()