diff options
Diffstat (limited to 'paper/figures/render_fig5_cross_arch.py')
| -rw-r--r-- | paper/figures/render_fig5_cross_arch.py | 69 |
1 files changed, 46 insertions, 23 deletions
diff --git a/paper/figures/render_fig5_cross_arch.py b/paper/figures/render_fig5_cross_arch.py index 9d52e09..470313d 100644 --- a/paper/figures/render_fig5_cross_arch.py +++ b/paper/figures/render_fig5_cross_arch.py @@ -1,4 +1,20 @@ -"""Render Figure 5: cross-architecture verdict matrix.""" +"""Render Figure 5: cross-architecture verdict matrix. + +Verdict encoding: + 0 = passes (✓, blue) + 1 = walked back (WB, red) + 2 = not measured for this architecture (—, gray) + +Sources: + ResMLP d=256 row: results/protocol_audit/audit_table_s42_s123_s456.json + + temporal_evolution_s{42,123,456}.json (no-LN row) + ResMLP d=512 row: results/protocol_audit/audit_d512_3seed.json + ViT-Mini row: snapshot_vit_v1 + ViT walk-back memory (acc 0.237 vs frozen ~0.255-0.261) + no-terminal-LN ResMLP row: results/snapshot_no_outln_v1 (3-seed acc 0.327 ± 0.012 vs + proxy frozen baseline 0.349 ± 0.002 → fails (d) by 2.2 pp) + CNN BatchNorm row: results/protocol_audit/audit_cnn_3seed.json + (no CNN frozen baseline → (c)+(d) not measured) +""" import os import matplotlib matplotlib.use("Agg") @@ -7,45 +23,52 @@ import numpy as np REPO_ROOT = "/home/yurenh2/fa" -# Verdict matrix: arch x diagnostic -# 0 = ok (BP), 1 = ok-non-LN-arch, 2 = walk-back -# Columns: (a) per-block growth, (b) ||g_L|| floor, (c) drift stability, (d) frozen baseline # Rows: ResMLP-d256, ResMLP-d512, ViT-Mini, no-terminal-LN ResMLP-d256, CNN (BN, no LN) -arches = ["ResMLP $d{=}256$\n(terminal LN)", "ResMLP $d{=}512$\n(terminal LN)", "ViT-Mini\n(cls + LN)", "ResMLP $d{=}256$\n(no terminal LN)", "CNN BatchNorm\n(no terminal LN)"] +arches = ["ResMLP $d{=}256$\n(terminal LN)", + "ResMLP $d{=}512$\n(terminal LN)", + "ViT-Mini\n(cls + LN)", + "ResMLP $d{=}256$\n(no terminal LN)", + "CNN BatchNorm\n(no terminal LN)"] diags = ["(a) scale", "(b) ${\\|g\\|}$ floor", "(c) drift", "(d) frozen"] -# DFA verdicts on each -# 1 = fires (walked back); 0 = passes +# DFA verdicts on each. 0 = passes, 1 = fires (walked back), 2 = not measured. dfa = np.array([ - [1, 1, 0, 1], # ResMLP d256: (a) fires, (b) fires, (c) noise sub-mode, (d) fires - [1, 1, 0, 1], # ResMLP d512: same pattern - [1, 1, 0, 1], # ViT-Mini: same pattern - [1, 0, 0, 0], # ResMLP no-LN: only (a) fires; (b) NEVER - [1, 0, 0, 0], # CNN BN: only (a) fires; (b) NEVER (the killer (b)-is-LN-specific finding) + [1, 1, 0, 1], # ResMLP d=256, terminal LN: (a)+(b)+(d) fire, (c) passes (3-seed mean stab 0.16) + [1, 1, 0, 1], # ResMLP d=512, terminal LN: same pattern, (c) 3-seed mean 0.16 passes + [1, 1, 2, 1], # ViT-Mini: (a)+(b)+(d) fire (acc 0.237 vs frozen ~0.26), (c) not measured + [1, 0, 2, 1], # no-LN ResMLP: (a) fires, (b) NEVER, (c) not measured, (d) FIRES (acc 0.327 vs frozen 0.349) + [1, 0, 2, 2], # CNN BN: (a) fires, (b) NEVER, (c)+(d) not measured (no CNN frozen baseline) ]) bp = np.zeros_like(dfa) # BP: passes everywhere -fig, axes = plt.subplots(1, 2, figsize=(11, 4.2)) +fig, axes = plt.subplots(1, 2, figsize=(11, 3.2)) + +# 3-color map: blue (pass), red (fire), gray (not measured) +cmap = matplotlib.colors.ListedColormap(["#4682b4", "#cc4444", "#888888"]) +norm = matplotlib.colors.BoundaryNorm([-0.5, 0.5, 1.5, 2.5], cmap.N) -for ax, mat, title in [(axes[0], bp, "BP-trained: protocol passes"), +for ax, mat, title in [(axes[0], bp, "BP-trained: protocol passes everywhere"), (axes[1], dfa, "DFA-trained: protocol verdict by architecture")]: - cmap = matplotlib.colors.ListedColormap(["#4682b4", "#cc4444"]) - im = ax.imshow(mat, cmap=cmap, aspect="auto", vmin=0, vmax=1) + ax.imshow(mat, cmap=cmap, norm=norm, aspect="auto") for i in range(mat.shape[0]): for j in range(mat.shape[1]): - txt = "WB" if mat[i, j] == 1 else "$\\checkmark$" - ax.text(j, i, txt, ha="center", va="center", color="white", fontsize=11, fontweight="bold") + v = mat[i, j] + txt = {0: r"$\checkmark$", 1: "WB", 2: "—"}[v] + ax.text(j, i, txt, ha="center", va="center", + color="white", fontsize=11, fontweight="bold") ax.set_xticks(range(len(diags))) ax.set_xticklabels(diags, fontsize=9) ax.set_yticks(range(len(arches))) - ax.set_yticklabels(arches, fontsize=9) + ax.set_yticklabels(arches, fontsize=8) ax.set_title(title, fontsize=10) -# Highlight the key finding — place well below the multiline x/y tick labels to avoid overlap -axes[1].text(0.5, -1.55, "Key finding: diagnostic (b) BP-grad-floor fires only on terminal-LN architectures.\n" - "Across the 5 architecture cases tested, (b) is restricted to the with-terminal-LN family.", - ha="center", fontsize=9, style="italic", transform=axes[1].transAxes) +# Key finding caption directly below the figure (not floating far below). +fig.text(0.5, -0.05, + "Key finding: diagnostic (b) BP-grad-floor fires only on terminal-LN architectures. " + "Of the 5 architectures audited, (b) is restricted to the with-terminal-LN family. " + "Cells marked — are not applicable for that architecture (no matched frozen-blocks baseline / no stability run).", + ha="center", fontsize=8, style="italic", wrap=True) plt.tight_layout() out = os.path.join(REPO_ROOT, "paper/figures/fig5_cross_arch_summary.pdf") |
