summaryrefslogtreecommitdiff
path: root/paper/figures/render_fig5_cross_arch.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 04:46:59 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 04:46:59 -0500
commit07b10f06478514bbe9d9c77461a90f9d3254218b (patch)
tree4f559a8131159e47da6ffe1666207eba96b02688 /paper/figures/render_fig5_cross_arch.py
parent58259e151858a545e359c2134b1db84bee3a4be6 (diff)
Fill in tables 1-3 + generate figures 2/4/5 from existing data
Tables filled with real values: Table 1: 5-method audit (3-seed mean ± std for acc, headline Γ, verdict) Table 2: 4-condition mode 2 validation (cos and ρ values from existing checkpoint measurements) Table 3: protocol thresholds (50×, 1e-7, 0.30, 2pp) Figures generated from existing data: fig2_decision_utility.pdf: 5×7 verdict heatmap from results/protocol_audit/ablation_decision_utility.json fig4_penalty_rescue.pdf: 3-panel — trajectory + cos/ρ bars + 2×2 acc from snapshot_evolution_v2 + dfa_residual_penalty + bp_with_penalty fig5_cross_arch_summary.pdf: 5×4 BP/DFA verdict matrix across architectures Compiles to 8 pages with all tables/figures rendered. §1-§7 main body still has only paragraph topic sentences (TODO: per-section prose filling via codex). Figure numbering is wrong (codex put figures in section order not numerical order — need fixing).
Diffstat (limited to 'paper/figures/render_fig5_cross_arch.py')
-rw-r--r--paper/figures/render_fig5_cross_arch.py53
1 files changed, 53 insertions, 0 deletions
diff --git a/paper/figures/render_fig5_cross_arch.py b/paper/figures/render_fig5_cross_arch.py
new file mode 100644
index 0000000..e163956
--- /dev/null
+++ b/paper/figures/render_fig5_cross_arch.py
@@ -0,0 +1,53 @@
+"""Render Figure 5: cross-architecture verdict matrix."""
+import os
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import numpy as np
+
+REPO_ROOT = "/home/yurenh2/fa"
+
+# Verdict matrix: arch x diagnostic
+# 0 = ok (BP), 1 = ok-non-LN-arch, 2 = walk-back
+# Columns: (a) per-block growth, (b) ||g_L|| floor, (c) drift stability, (d) frozen baseline
+# Rows: ResMLP-d256, ResMLP-d512, ViT-Mini, StudentNet (no LN), CNN (BN, no LN)
+arches = ["ResMLP $d{=}256$\n(terminal LN)", "ResMLP $d{=}512$\n(terminal LN)", "ViT-Mini\n(cls + LN)", "StudentNet\n(no terminal LN)", "CNN BatchNorm\n(no terminal LN)"]
+diags = ["(a) scale", "(b) ${\\|g\\|}$ floor", "(c) drift", "(d) frozen"]
+
+# DFA verdicts on each
+# 1 = fires (walked back); 0 = passes
+dfa = np.array([
+ [1, 1, 0, 1], # ResMLP d256: (a) fires, (b) fires, (c) noise sub-mode, (d) fires
+ [1, 1, 0, 1], # ResMLP d512: same pattern
+ [1, 1, 0, 1], # ViT-Mini: same pattern
+ [1, 0, 0, 0], # StudentNet: only (a) fires; (b) NEVER
+ [1, 0, 0, 0], # CNN BN: only (a) fires; (b) NEVER (the killer (b)-is-LN-specific finding)
+])
+
+bp = np.zeros_like(dfa) # BP: passes everywhere
+
+fig, axes = plt.subplots(1, 2, figsize=(11, 3.5))
+
+for ax, mat, title in [(axes[0], bp, "BP-trained: protocol passes"),
+ (axes[1], dfa, "DFA-trained: protocol verdict by architecture")]:
+ cmap = matplotlib.colors.ListedColormap(["#4682b4", "#cc4444"])
+ im = ax.imshow(mat, cmap=cmap, aspect="auto", vmin=0, vmax=1)
+ for i in range(mat.shape[0]):
+ for j in range(mat.shape[1]):
+ txt = "WB" if mat[i, j] == 1 else "$\\checkmark$"
+ ax.text(j, i, txt, ha="center", va="center", color="white", fontsize=11, fontweight="bold")
+ ax.set_xticks(range(len(diags)))
+ ax.set_xticklabels(diags, fontsize=9)
+ ax.set_yticks(range(len(arches)))
+ ax.set_yticklabels(arches, fontsize=9)
+ ax.set_title(title, fontsize=10)
+
+# Highlight the key finding
+axes[1].text(0.5, -1.0, "Key finding: diagnostic (b) BP-grad-floor fires only on terminal-LN architectures.\n"
+ "Across the 5 architecture cases tested, (b) is restricted to the with-terminal-LN family.",
+ ha="center", fontsize=9, style="italic", transform=axes[1].transAxes)
+
+plt.tight_layout()
+out = os.path.join(REPO_ROOT, "paper/figures/fig5_cross_arch_summary.pdf")
+plt.savefig(out, bbox_inches="tight", dpi=200)
+print(f"Saved {out}")