summaryrefslogtreecommitdiff
path: root/paper/figures/render_fig5_cross_arch.py
blob: 470313dbef0c5c10b1a11647171c5d82043a26e8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Render Figure 5: cross-architecture verdict matrix.

Verdict encoding:
  0 = passes (✓, blue)
  1 = walked back (WB, red)
  2 = not measured for this architecture (—, gray)

Sources:
  ResMLP d=256 row: results/protocol_audit/audit_table_s42_s123_s456.json
                    + temporal_evolution_s{42,123,456}.json (no-LN row)
  ResMLP d=512 row: results/protocol_audit/audit_d512_3seed.json
  ViT-Mini row: snapshot_vit_v1 + ViT walk-back memory (acc 0.237 vs frozen ~0.255-0.261)
  no-terminal-LN ResMLP row: results/snapshot_no_outln_v1 (3-seed acc 0.327 ± 0.012 vs
                              proxy frozen baseline 0.349 ± 0.002 → fails (d) by 2.2 pp)
  CNN BatchNorm row: results/protocol_audit/audit_cnn_3seed.json
                     (no CNN frozen baseline → (c)+(d) not measured)
"""
import os
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Repository root, derived from this script's own location
# (<repo>/paper/figures/render_fig5_cross_arch.py) so the figure renders on any
# checkout rather than only under one hard-coded home directory.
REPO_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)

# Rows: ResMLP-d256, ResMLP-d512, ViT-Mini, no-terminal-LN ResMLP-d256, CNN (BN, no LN)
arches = ["ResMLP $d{=}256$\n(terminal LN)",
          "ResMLP $d{=}512$\n(terminal LN)",
          "ViT-Mini\n(cls + LN)",
          "ResMLP $d{=}256$\n(no terminal LN)",
          "CNN BatchNorm\n(no terminal LN)"]
# Columns: the four protocol diagnostics, in plotting order.
diags = ["(a) scale", "(b) ${\\|g\\|}$ floor", "(c) drift", "(d) frozen"]

# DFA verdicts on each. 0 = passes, 1 = fires (walked back), 2 = not measured.
dfa = np.array([
    [1, 1, 0, 1],   # ResMLP d=256, terminal LN: (a)+(b)+(d) fire, (c) passes (3-seed mean stab 0.16)
    [1, 1, 0, 1],   # ResMLP d=512, terminal LN: same pattern, (c) 3-seed mean 0.16 passes
    [1, 1, 2, 1],   # ViT-Mini: (a)+(b)+(d) fire (acc 0.237 vs frozen ~0.26), (c) not measured
    [1, 0, 2, 1],   # no-LN ResMLP: (a) fires, (b) NEVER, (c) not measured, (d) FIRES (acc 0.327 vs frozen 0.349)
    [1, 0, 2, 2],   # CNN BN: (a) fires, (b) NEVER, (c)+(d) not measured (no CNN frozen baseline)
])

# The matrix is hand-edited: fail fast if it drifts out of sync with the axis
# labels or contains a code the colormap/glyph table does not know about.
if dfa.shape != (len(arches), len(diags)):
    raise ValueError(
        f"verdict matrix shape {dfa.shape} does not match "
        f"{len(arches)} architectures x {len(diags)} diagnostics"
    )
if not np.isin(dfa, (0, 1, 2)).all():
    raise ValueError(
        "verdict codes must be 0 (pass), 1 (walked back) or 2 (not measured)"
    )

# BP verdicts: every *measured* diagnostic passes. Cells that the DFA matrix marks
# as not measured (2) are architecture-level gaps (no matched frozen-blocks
# baseline / no stability run), so they are not measured for BP either; drawing
# a checkmark there would claim a pass that was never evaluated.
bp = np.where(dfa == 2, 2, 0)

fig, axes = plt.subplots(1, 2, figsize=(11, 3.2))

# 3-color map: blue (pass), red (fire), gray (not measured)
cmap = matplotlib.colors.ListedColormap(["#4682b4", "#cc4444", "#888888"])
# Bin edges at half-integers so verdict codes 0/1/2 land exactly on colors 0/1/2.
norm = matplotlib.colors.BoundaryNorm([-0.5, 0.5, 1.5, 2.5], cmap.N)
# Glyph drawn in each cell, keyed by verdict code (built once, not per cell).
glyphs = {0: r"$\checkmark$", 1: "WB", 2: "—"}

panels = [(axes[0], bp, "BP-trained: protocol passes wherever measured"),
          (axes[1], dfa, "DFA-trained: protocol verdict by architecture")]
for ax, mat, title in panels:
    ax.imshow(mat, cmap=cmap, norm=norm, aspect="auto")
    for (i, j), v in np.ndenumerate(mat):
        # imshow puts column j at x=j and row i at y=i.
        ax.text(j, i, glyphs[int(v)], ha="center", va="center",
                color="white", fontsize=11, fontweight="bold")
    ax.set_xticks(range(len(diags)))
    ax.set_xticklabels(diags, fontsize=9)
    ax.set_yticks(range(len(arches)))
    ax.set_yticklabels(arches, fontsize=8)
    ax.set_title(title, fontsize=10)

# Key finding caption directly below the figure (not floating far below).
# It sits outside the axes area; bbox_inches="tight" on save pulls it into the PDF.
fig.text(0.5, -0.05,
         "Key finding: diagnostic (b) BP-grad-floor fires only on terminal-LN architectures. "
         "Of the 5 architectures audited, (b) is restricted to the with-terminal-LN family. "
         "Cells marked — are not applicable for that architecture (no matched frozen-blocks baseline / no stability run).",
         ha="center", fontsize=8, style="italic", wrap=True)

plt.tight_layout()
out = os.path.join(REPO_ROOT, "paper/figures/fig5_cross_arch_summary.pdf")
plt.savefig(out, bbox_inches="tight", dpi=200)
plt.close(fig)
print(f"Saved {out}")