summaryrefslogtreecommitdiff
path: root/paper/figures/render_fig5_cross_arch.py
blob: 9ad9ce21fd624150b44bf499cad70c69857bcd85 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""Render Figure 5: cross-architecture verdict matrix."""
import os
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Repository root used to build the output path below.  This script lives at
# <root>/paper/figures/<script>.py, so walk three directory levels up from this
# file rather than hard-coding a single developer's home directory; that keeps
# the figure renderable from any checkout location.
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Verdict matrices: one row per architecture (``arches``), one column per
# diagnostic (``diags``).  The encoding is binary, matching the two-colour map
# (vmin=0, vmax=1) used when rendering:
#   0 = diagnostic passes (cell drawn with a check mark)
#   1 = diagnostic fires, i.e. the verdict is walked back (cell drawn "WB")
# Columns: (a) per-block growth, (b) ||g_L|| floor, (c) drift stability, (d) frozen baseline
# Rows: ResMLP-d256, ResMLP-d512, ViT-Mini, StudentNet (no LN), CNN (BN, no LN)
arches = [
    "ResMLP $d{=}256$\n(terminal LN)",
    "ResMLP $d{=}512$\n(terminal LN)",
    "ViT-Mini\n(cls + LN)",
    "StudentNet\n(no terminal LN)",
    "CNN BatchNorm\n(no terminal LN)",
]
diags = ["(a) scale", "(b) ${\\|g\\|}$ floor", "(c) drift", "(d) frozen"]

# DFA verdicts per architecture/diagnostic pair (1 = fires / walked back, 0 = passes).
dfa = np.array([
    [1, 1, 0, 1],   # ResMLP d256: (a) fires, (b) fires, (c) noise sub-mode, (d) fires
    [1, 1, 0, 1],   # ResMLP d512: same pattern
    [1, 1, 0, 1],   # ViT-Mini: same pattern
    [1, 0, 0, 0],   # StudentNet: only (a) fires; (b) never
    [1, 0, 0, 0],   # CNN BN: only (a) fires; (b) never -- the (b)-is-LN-specific finding
])

bp = np.zeros_like(dfa)  # BP: passes everywhere

# Fail loudly if the labels and the matrix drift out of sync, or if a value
# outside the binary encoding sneaks in; otherwise the heatmap would silently
# mislabel rows/columns or mis-colour cells.
if dfa.shape != (len(arches), len(diags)):
    raise ValueError(
        f"verdict matrix shape {dfa.shape} does not match "
        f"{len(arches)} architectures x {len(diags)} diagnostics"
    )
if not np.isin(dfa, (0, 1)).all():
    raise ValueError("verdict matrix must contain only 0 (passes) or 1 (walked back)")

# Two-entry colour map indexed by verdict value (with vmin=0, vmax=1 below):
# 0 -> steel blue (passes), 1 -> red (walked back).  Built once, shared by both panels.
_VERDICT_CMAP = matplotlib.colors.ListedColormap(["#4682b4", "#cc4444"])


def _draw_verdict_panel(ax, mat, title, row_labels, col_labels):
    """Render one binary verdict matrix onto *ax* as an annotated colour grid.

    Args:
        ax: Matplotlib axes to draw into.
        mat: 2-D array of verdicts (rows = architectures, columns = diagnostics),
            expected to hold 0 (passes) or 1 (walked back).
        title: Panel title.
        row_labels: Y-axis tick labels, one per row of *mat*.
        col_labels: X-axis tick labels, one per column of *mat*.

    Cells equal to 1 are annotated "WB"; every other cell gets a check mark.
    """
    ax.imshow(mat, cmap=_VERDICT_CMAP, aspect="auto", vmin=0, vmax=1)
    for (i, j), value in np.ndenumerate(mat):
        # imshow puts column j at x=j and row i at y=i, so annotate at (j, i).
        label = "WB" if value == 1 else "$\\checkmark$"
        ax.text(j, i, label, ha="center", va="center", color="white",
                fontsize=11, fontweight="bold")
    ax.set_xticks(range(len(col_labels)))
    ax.set_xticklabels(col_labels, fontsize=9)
    ax.set_yticks(range(len(row_labels)))
    ax.set_yticklabels(row_labels, fontsize=9)
    ax.set_title(title, fontsize=10)


fig, axes = plt.subplots(1, 2, figsize=(11, 4.2))
_draw_verdict_panel(axes[0], bp, "BP-trained: protocol passes", arches, diags)
_draw_verdict_panel(axes[1], dfa, "DFA-trained: protocol verdict by architecture",
                    arches, diags)

# Lay out the panels BEFORE adding the caption.  A caption attached to an axes
# (ax.text, unclipped) is part of that axes' layout bbox, so tight_layout would
# shrink the panels to reserve room for it.  A figure-level text added after the
# layout pass stays out of that computation, and savefig(bbox_inches="tight")
# still grows the saved canvas to include it.  va="top" at y just below 0 puts
# the caption under the x tick labels, centred across both panels.
fig.tight_layout()
fig.text(0.5, -0.02,
         "Key finding: diagnostic (b) BP-grad-floor fires only on terminal-LN architectures.\n"
         "Across the 5 architecture cases tested, (b) is restricted to the with-terminal-LN family.",
         ha="center", va="top", fontsize=9, style="italic")

out = os.path.join(REPO_ROOT, "paper/figures/fig5_cross_arch_summary.pdf")
os.makedirs(os.path.dirname(out), exist_ok=True)
fig.savefig(out, bbox_inches="tight", dpi=200)
plt.close(fig)
print(f"Saved {out}")