diff options
Diffstat (limited to 'paper/figures/render_fig5_cross_arch.py')
| -rw-r--r-- | paper/figures/render_fig5_cross_arch.py | 53 |
1 files changed, 53 insertions, 0 deletions
diff --git a/paper/figures/render_fig5_cross_arch.py b/paper/figures/render_fig5_cross_arch.py new file mode 100644 index 0000000..e163956 --- /dev/null +++ b/paper/figures/render_fig5_cross_arch.py @@ -0,0 +1,53 @@ +"""Render Figure 5: cross-architecture verdict matrix.""" +import os +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np + +REPO_ROOT = "/home/yurenh2/fa" + +# Verdict matrix: arch x diagnostic +# 0 = ok (BP), 1 = ok-non-LN-arch, 2 = walk-back +# Columns: (a) per-block growth, (b) ||g_L|| floor, (c) drift stability, (d) frozen baseline +# Rows: ResMLP-d256, ResMLP-d512, ViT-Mini, StudentNet (no LN), CNN (BN, no LN) +arches = ["ResMLP $d{=}256$\n(terminal LN)", "ResMLP $d{=}512$\n(terminal LN)", "ViT-Mini\n(cls + LN)", "StudentNet\n(no terminal LN)", "CNN BatchNorm\n(no terminal LN)"] +diags = ["(a) scale", "(b) ${\\|g\\|}$ floor", "(c) drift", "(d) frozen"] + +# DFA verdicts on each +# 1 = fires (walked back); 0 = passes +dfa = np.array([ + [1, 1, 0, 1], # ResMLP d256: (a) fires, (b) fires, (c) noise sub-mode, (d) fires + [1, 1, 0, 1], # ResMLP d512: same pattern + [1, 1, 0, 1], # ViT-Mini: same pattern + [1, 0, 0, 0], # StudentNet: only (a) fires; (b) NEVER + [1, 0, 0, 0], # CNN BN: only (a) fires; (b) NEVER (the killer (b)-is-LN-specific finding) +]) + +bp = np.zeros_like(dfa) # BP: passes everywhere + +fig, axes = plt.subplots(1, 2, figsize=(11, 3.5)) + +for ax, mat, title in [(axes[0], bp, "BP-trained: protocol passes"), + (axes[1], dfa, "DFA-trained: protocol verdict by architecture")]: + cmap = matplotlib.colors.ListedColormap(["#4682b4", "#cc4444"]) + im = ax.imshow(mat, cmap=cmap, aspect="auto", vmin=0, vmax=1) + for i in range(mat.shape[0]): + for j in range(mat.shape[1]): + txt = "WB" if mat[i, j] == 1 else "$\\checkmark$" + ax.text(j, i, txt, ha="center", va="center", color="white", fontsize=11, fontweight="bold") + ax.set_xticks(range(len(diags))) + ax.set_xticklabels(diags, fontsize=9) + ax.set_yticks(range(len(arches))) + ax.set_yticklabels(arches, fontsize=9) + ax.set_title(title, fontsize=10) + +# Highlight the key finding +axes[1].text(0.5, -1.0, "Key finding: diagnostic (b) BP-grad-floor fires only on terminal-LN architectures.\n" + "Across the 5 architecture cases tested, (b) is restricted to the with-terminal-LN family.", + ha="center", fontsize=9, style="italic", transform=axes[1].transAxes) + +plt.tight_layout() +out = os.path.join(REPO_ROOT, "paper/figures/fig5_cross_arch_summary.pdf") +plt.savefig(out, bbox_inches="tight", dpi=200) +print(f"Saved {out}") |
