# Source: root/paper/figures/render_fig2_decision_utility.py
# Blob: ba1a1485670fa628f2ecad1fbf8739af006fd6b2 (cgit "plain" view; navigation bar and line-number gutter removed)
"""Render Figure 2: decision-utility ablation as a heatmap."""
import os
import json
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Root of the repository checkout; the input JSON and the output figure are
# both resolved relative to it.
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
REPO_ROOT = "/home/yurenh2/fa"

# Load the ablation results. The file is decoded explicitly as UTF-8 (the
# conventional JSON encoding) instead of relying on the platform's locale
# default, which differs between machines.
with open(
    os.path.join(REPO_ROOT, "results/protocol_audit/ablation_decision_utility.json"),
    encoding="utf-8",
) as f:
    data = json.load(f)

# Heatmap rows: (method key used to index the results table, display label).
_METHOD_ROWS = (
    ("bp", "BP"),
    ("ep", "EP"),
    ("dfa", "DFA"),
    ("credit_bridge", "Credit Bridge"),
    ("state_bridge", "State Bridge"),
)
methods = [key for key, _ in _METHOD_ROWS]
method_labels = [label for _, label in _METHOD_ROWS]

# Heatmap columns: (strategy key used to index the results table, multi-line
# tick label; the labels contain mathtext such as $\Gamma$).
_STRATEGY_COLUMNS = (
    ("S0", "S0\nacc only"),
    ("S1", "S1\n+$\\Gamma$\n(field std)"),
    ("S2", "S2\n+(a)\nscale"),
    ("S3", "S3\n+(b)\nfloor"),
    ("S4", "S4\n+(c)\nstability"),
    ("S5", "S5\n+(d)\nfrozen"),
    ("S_full", "S$_\\mathrm{full}$\nfull"),
)
strategies = [key for key, _ in _STRATEGY_COLUMNS]
strategy_labels = [label for _, label in _STRATEGY_COLUMNS]

# Verdict matrix with one row per method and one column per strategy:
# 1.0 = walk back (red), 0.0 = trustworthy (blue).
mat = np.zeros((len(methods), len(strategies)))
for i, m in enumerate(methods):
    verdict_row = data["table"][m]
    # A cell is flagged when its verdict entry contains "WALK".
    mat[i] = [1.0 if "WALK" in verdict_row[s] else 0.0 for s in strategies]

fig, ax = plt.subplots(figsize=(7, 3.2))
cmap = matplotlib.colors.ListedColormap(["#4682b4", "#cc4444"])
im = ax.imshow(mat, cmap=cmap, aspect="auto", vmin=0, vmax=1)

# Annotate each cell with its verdict: "WB" for a walk back, a check mark
# otherwise. Text is placed at (column, row), i.e. x = strategy index and
# y = method index, matching the image drawn by imshow.
for i, m in enumerate(methods):
    for j, s in enumerate(strategies):
        v = data["table"][m][s]
        txt = "WB" if "WALK" in v else "$\\checkmark$"
        # White text is used on both the blue and the red cells, so the colour
        # is a constant rather than a per-cell choice.
        ax.text(j, i, txt, ha="center", va="center",
                color="white", fontsize=10, fontweight="bold")

# Tick every row and column of the heatmap: methods label the rows (y axis),
# reporting strategies label the columns (x axis).
row_ticks = range(len(methods))
col_ticks = range(len(strategies))

ax.set_yticks(row_ticks)
ax.set_yticklabels(method_labels, fontsize=10)

ax.set_xticks(col_ticks)
ax.set_xticklabels(strategy_labels, fontsize=8)
ax.set_xlabel("reporting strategy", fontsize=10)

# Colour legend: one swatch per verdict, using the same two hex colours as the
# heatmap's colormap.
from matplotlib.patches import Patch
legend_elements = [
    Patch(facecolor=swatch_colour, label=swatch_label)
    for swatch_colour, swatch_label in (
        ("#4682b4", "$\\checkmark$ trustworthy"),
        ("#cc4444", "WB walked back"),
    )
]
# Anchor the legend's top centre below the axes (y = -0.25 in axes
# coordinates), laid out as a single frameless row of two entries.
ax.legend(
    handles=legend_elements,
    loc="upper center",
    bbox_to_anchor=(0.5, -0.25),
    ncol=2,
    fontsize=9,
    frameon=False,
)

# Two-line title: what the figure shows, then the headline result.
# NOTE(review): the 0/5 and 3/5 counts are hard-coded in the string, not
# derived from `data` -- re-check them whenever the results JSON changes.
title_lines = (
    "Decision-utility ablation: 7 reporting strategies vs 5 methods",
    "field-standard pair (S0, S1) walks back 0/5; full protocol walks back 3/5",
)
ax.set_title("\n".join(title_lines), fontsize=10)
plt.tight_layout()

# Write the figure next to the paper sources and report where it went.
out = os.path.join(REPO_ROOT, "paper/figures/fig2_decision_utility.pdf")
save_options = {"bbox_inches": "tight", "dpi": 200}
plt.savefig(out, **save_options)
print(f"Saved {out}")