summaryrefslogtreecommitdiff
path: root/protocol/examples/plot_audit_table.py
blob: d931eb922756db6e0523187658e7e3e2aa6fe818 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
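Plot the §2 5-method audit table as a paper-ready figure.

Layout: 4 panels — three diagnostics (max per-block growth of ‖h_l‖, the
last BP gradient norm ‖g_L‖, cross-batch stability) plus test accuracy.
Each panel is a horizontal bar chart of the 3-seed mean ± std per method,
with the diagnostic's threshold (or, for accuracy, the frozen baseline)
drawn as a dashed reference line. Methods are sorted top-to-bottom by
ascending accuracy.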

Run:
    python -m protocol.examples.plot_audit_table
"""
import os
import sys
import json
import math

import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Repository root: this file lives at <root>/protocol/examples/, so go two
# directories up from its own folder. It is prepended to sys.path, presumably
# so `protocol.*` imports resolve when the file is run directly.
REPO_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
)
sys.path.insert(0, REPO_ROOT)

# Audit results JSON for seeds 42, 123 and 456 (per the filename).
AUDIT_PATH = os.path.join(REPO_ROOT, "results/protocol_audit/audit_table_s42_s123_s456.json")


def max_per_block_growth(h):
    """Return the largest ratio between consecutive entries of ``h``.

    ``h`` is the sequence of per-block residual norms. The result is
    ``max_i h[i+1] / h[i]``; each denominator is floored at 1e-30, so a zero
    norm does not raise ZeroDivisionError. A sequence with fewer than two
    entries has no consecutive pair and yields 1.0 (no growth).
    """
    if len(h) >= 2:
        ratios = (nxt / max(cur, 1e-30) for cur, nxt in zip(h, h[1:]))
        return max(ratios)
    return 1.0


def _barh_panel(ax, y, per_seed, colors, *, ref, ref_label, title, xlabel,
                legend_loc, log=False, yticklabels=()):
    """Draw one audit panel: mean ± std horizontal bars plus a dashed reference line.

    Args:
        ax: matplotlib Axes to draw on.
        y: bar positions, one per method.
        per_seed: one list of per-seed values per bar, aligned with ``y``.
        colors: bar colors, aligned with ``y``.
        ref: x position of the dashed reference (threshold / baseline) line.
        ref_label: legend label for the reference line.
        title: panel title.
        xlabel: x-axis label.
        legend_loc: matplotlib legend location string.
        log: if True, use a log-scaled x axis and draw minor grid lines too.
        yticklabels: y tick labels; empty (the default) hides them, so that
            only the leftmost panel shows method names.

    Returns:
        ``(means, stds)`` lists, aligned with ``y``. The std is ``np.std``'s
        default population std (ddof=0).
    """
    means = [np.mean(v) for v in per_seed]
    stds = [np.std(v) for v in per_seed]
    ax.barh(y, means, xerr=stds, color=colors, alpha=0.85, capsize=3)
    if log:
        ax.set_xscale("log")
    ax.axvline(ref, color="k", linestyle="--", lw=1.2, label=ref_label)
    ax.set_yticks(y)
    ax.set_yticklabels(list(yticklabels), fontsize=10)
    ax.set_xlabel(xlabel, fontsize=10)
    ax.set_title(title, fontsize=11)
    ax.legend(loc=legend_loc, fontsize=8)
    ax.grid(True, axis="x", which="both" if log else "major", alpha=0.3)
    return means, stds


def main():
    """Build the 4-panel audit figure from AUDIT_PATH and save it as a PNG.

    For each listed method, every ``data["summary"]`` row contributes:
    - its ``acc`` and ``stability``;
    - the max per-block growth of the matching report's ``residual_norms``,
      where the report is ``data["reports"][f"{method}_s{seed}"]``;
    - the last entry of that report's ``bp_grad_norms``.

    Bars show the per-method mean ± std over those rows. Methods are ordered
    top-to-bottom by ascending mean accuracy. Summary rows for unlisted
    methods are ignored. Listed methods with no rows are omitted from the
    figure, with a printed warning.

    Raises:
        KeyError: if a summary row has no matching report entry.
        ValueError: if none of the listed methods has any summary row.
    """
    with open(AUDIT_PATH, encoding="utf-8") as f:
        data = json.load(f)

    all_methods = ["bp", "ep", "dfa", "credit_bridge", "state_bridge"]
    method_labels = {
        "bp": "BP",
        "ep": "EP",
        "dfa": "DFA",
        "credit_bridge": "Credit\nBridge",
        "state_bridge": "State\nBridge",
    }

    # Aggregate per-method values across seeds
    agg = {m: {"acc": [], "max_per_block": [], "g_L": [], "stab": []} for m in all_methods}
    for row in data["summary"]:
        m = row["method"]
        if m not in agg:
            continue
        report = data["reports"][f"{m}_s{row['seed']}"]
        agg[m]["acc"].append(row["acc"])
        agg[m]["max_per_block"].append(max_per_block_growth(report["residual_norms"]))
        agg[m]["g_L"].append(report["bp_grad_norms"][-1])
        agg[m]["stab"].append(row["stability"])

    # A method with no rows would give np.mean([]) -> NaN (with a
    # RuntimeWarning) and an empty, misleading bar; drop it explicitly.
    present = [m for m in all_methods if agg[m]["acc"]]
    missing = [m for m in all_methods if not agg[m]["acc"]]
    if missing:
        print(f"Warning: no audit rows for {', '.join(missing)}; omitted from figure")
    if not present:
        raise ValueError(f"no audit rows for any of {all_methods} in {AUDIT_PATH}")

    # barh puts y=0 at the bottom, so sorting by *descending* mean accuracy
    # puts the lowest-accuracy method on top (ascending top-to-bottom).
    # sorted() is stable, so ties keep the listing order above.
    methods = sorted(present, key=lambda m: np.mean(agg[m]["acc"]), reverse=True)

    # Healthy/degenerate color coding
    healthy = {"bp", "ep"}
    colors = ["#4682b4" if m in healthy else "#cc4444" for m in methods]

    fig, axes = plt.subplots(1, 4, figsize=(16, 4.5), gridspec_kw={"wspace": 0.5})
    y = np.arange(len(methods))

    # --- (a) max per-block growth --- #
    _barh_panel(
        axes[0], y, [agg[m]["max_per_block"] for m in methods], colors,
        ref=50, ref_label="threshold 50×",
        title="(a) ‖h_l‖ explosion", xlabel="max per-block growth (log)",
        legend_loc="lower right", log=True,
        yticklabels=[method_labels[m] for m in methods],
    )

    # --- (b) g_L floor --- #
    _barh_panel(
        axes[1], y, [agg[m]["g_L"] for m in methods], colors,
        ref=1e-7, ref_label="floor 1e-7",
        title="(b) BP grad at floor", xlabel="‖g_L‖ (log)",
        legend_loc="lower right", log=True,
    )

    # --- (c) stability --- #
    _barh_panel(
        axes[2], y, [agg[m]["stab"] for m in methods], colors,
        ref=0.30, ref_label="ceiling 0.30",
        title="(c) reference drift", xlabel="cross-batch stability",
        legend_loc="upper right",
    )
    axes[2].set_xlim(-0.2, 1.05)

    # --- (d) accuracy + frozen baseline --- #
    acc_means, acc_stds = _barh_panel(
        axes[3], y, [agg[m]["acc"] for m in methods], colors,
        ref=0.349, ref_label="frozen baseline 0.349",
        title="(d) headline acc vs frozen", xlabel="test accuracy",
        legend_loc="lower right",
    )
    # Keep the 0-0.7 range, but widen it rather than clip a bar or its error bar.
    acc_right = max(0.7, 1.05 * max(mu + sd for mu, sd in zip(acc_means, acc_stds)))
    axes[3].set_xlim(0, acc_right)

    fig.suptitle(
        "5-method audit on 4-block d=256 ResMLP CIFAR-10 (3-seed mean ± std)",
        fontsize=13, y=1.02
    )
    # NOTE(review): tight_layout recomputes subplot spacing and may override
    # the wspace=0.5 requested above. Confirm which spacing is intended.
    fig.tight_layout()
    out_path = os.path.join(REPO_ROOT, "results/protocol_audit/figure_audit_5method.png")
    fig.savefig(out_path, dpi=140, bbox_inches="tight")
    print(f"Saved {out_path}")


if __name__ == "__main__":
    # Render the figure only when executed as a script (e.g. `python -m ...`),
    # not when this module is imported.
    main()