diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-08 04:46:59 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-08 04:46:59 -0500 |
| commit | 07b10f06478514bbe9d9c77461a90f9d3254218b (patch) | |
| tree | 4f559a8131159e47da6ffe1666207eba96b02688 /paper/figures | |
| parent | 58259e151858a545e359c2134b1db84bee3a4be6 (diff) | |
Fill in tables 1-3 + generate figures 2/4/5 from existing data
Tables filled with real values:
Table 1: 5-method audit (3-seed mean ± std for acc, headline Γ, verdict)
Table 2: 4-condition mode 2 validation (cos and ρ values from existing
checkpoint measurements)
Table 3: protocol thresholds (50×, 1e-7, 0.30, 2pp)
Figures generated from existing data:
fig2_decision_utility.pdf: 5×7 verdict heatmap from
results/protocol_audit/ablation_decision_utility.json
fig4_penalty_rescue.pdf: 3-panel — trajectory + cos/ρ bars + 2×2 acc
from snapshot_evolution_v2 + dfa_residual_penalty + bp_with_penalty
fig5_cross_arch_summary.pdf: 5×4 BP/DFA verdict matrix across
architectures
Compiles to 8 pages with all tables/figures rendered. §1-§7 main body
still has only paragraph topic sentences (TODO: per-section prose
filling via codex). Figure numbering is wrong (codex put figures in
section order not numerical order — need fixing).
Diffstat (limited to 'paper/figures')
| -rw-r--r-- | paper/figures/fig2_decision_utility.pdf | bin | 10529 -> 21132 bytes | |||
| -rw-r--r-- | paper/figures/fig4_penalty_rescue.pdf | bin | 10698 -> 32527 bytes | |||
| -rw-r--r-- | paper/figures/fig5_cross_arch_summary.pdf | bin | 11836 -> 31765 bytes | |||
| -rw-r--r-- | paper/figures/render_fig2_decision_utility.py | 56 | ||||
| -rw-r--r-- | paper/figures/render_fig4_penalty_rescue.py | 92 | ||||
| -rw-r--r-- | paper/figures/render_fig5_cross_arch.py | 53 |
6 files changed, 201 insertions, 0 deletions
diff --git a/paper/figures/fig2_decision_utility.pdf b/paper/figures/fig2_decision_utility.pdf Binary files differindex fae5c98..e37004a 100644 --- a/paper/figures/fig2_decision_utility.pdf +++ b/paper/figures/fig2_decision_utility.pdf diff --git a/paper/figures/fig4_penalty_rescue.pdf b/paper/figures/fig4_penalty_rescue.pdf Binary files differindex 07f9b95..0dcdce7 100644 --- a/paper/figures/fig4_penalty_rescue.pdf +++ b/paper/figures/fig4_penalty_rescue.pdf diff --git a/paper/figures/fig5_cross_arch_summary.pdf b/paper/figures/fig5_cross_arch_summary.pdf Binary files differindex c5ac3fc..93c0676 100644 --- a/paper/figures/fig5_cross_arch_summary.pdf +++ b/paper/figures/fig5_cross_arch_summary.pdf diff --git a/paper/figures/render_fig2_decision_utility.py b/paper/figures/render_fig2_decision_utility.py new file mode 100644 index 0000000..ba1a148 --- /dev/null +++ b/paper/figures/render_fig2_decision_utility.py @@ -0,0 +1,56 @@ +"""Render Figure 2: decision-utility ablation as a heatmap.""" +import os +import json +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np + +REPO_ROOT = "/home/yurenh2/fa" +with open(os.path.join(REPO_ROOT, "results/protocol_audit/ablation_decision_utility.json")) as f: + data = json.load(f) + +methods = ["bp", "ep", "dfa", "credit_bridge", "state_bridge"] +method_labels = ["BP", "EP", "DFA", "Credit Bridge", "State Bridge"] +strategies = ["S0", "S1", "S2", "S3", "S4", "S5", "S_full"] +strategy_labels = ["S0\nacc only", "S1\n+$\\Gamma$\n(field std)", "S2\n+(a)\nscale", "S3\n+(b)\nfloor", "S4\n+(c)\nstability", "S5\n+(d)\nfrozen", "S$_\\mathrm{full}$\nfull"] + +# Build a verdict matrix: 1 = walk back (red), 0 = trustworthy (blue) +mat = np.zeros((len(methods), len(strategies))) +for i, m in enumerate(methods): + for j, s in enumerate(strategies): + v = data["table"][m][s] + if "WALK" in v: + mat[i, j] = 1.0 + +fig, ax = plt.subplots(figsize=(7, 3.2)) +cmap = matplotlib.colors.ListedColormap(["#4682b4", "#cc4444"]) +im = ax.imshow(mat, cmap=cmap, aspect="auto", vmin=0, vmax=1) + +# Annotate each cell +for i in range(len(methods)): + for j in range(len(strategies)): + v = data["table"][methods[i]][strategies[j]] + txt = "WB" if "WALK" in v else "$\\checkmark$" + ax.text(j, i, txt, ha="center", va="center", + color="white" if mat[i, j] > 0.5 else "white", fontsize=10, fontweight="bold") + +ax.set_xticks(range(len(strategies))) +ax.set_xticklabels(strategy_labels, fontsize=8) +ax.set_yticks(range(len(methods))) +ax.set_yticklabels(method_labels, fontsize=10) +ax.set_xlabel("reporting strategy", fontsize=10) + +# Color legend +from matplotlib.patches import Patch +legend_elements = [ + Patch(facecolor="#4682b4", label="$\\checkmark$ trustworthy"), + Patch(facecolor="#cc4444", label="WB walked back"), +] +ax.legend(handles=legend_elements, loc="upper center", bbox_to_anchor=(0.5, -0.25), ncol=2, fontsize=9, frameon=False) + +ax.set_title("Decision-utility ablation: 7 reporting strategies vs 5 methods\nfield-standard pair (S0, S1) walks back 0/5; full protocol walks back 3/5", fontsize=10) +plt.tight_layout() +out = os.path.join(REPO_ROOT, "paper/figures/fig2_decision_utility.pdf") +plt.savefig(out, bbox_inches="tight", dpi=200) +print(f"Saved {out}") diff --git a/paper/figures/render_fig4_penalty_rescue.py b/paper/figures/render_fig4_penalty_rescue.py new file mode 100644 index 0000000..b7089ec --- /dev/null +++ b/paper/figures/render_fig4_penalty_rescue.py @@ -0,0 +1,92 @@ +"""Render Figure 4: penalty rescue + capacity-cost control.""" +import os +import json +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np + +REPO_ROOT = "/home/yurenh2/fa" + +# Panel A: penalty rescue trajectory +with open(os.path.join(REPO_ROOT, "results/snapshot_evolution_v2/snapshot_evolution_s42.json")) as f: + snap = json.load(f) +vanilla = snap["dfa_log"] +ep_vanilla = [e["epoch"] for e in vanilla] +hL_vanilla = [e["hidden_norms"][-1] for e in vanilla] +g_vanilla = [e["bp_grad_norms_per_sample_med"][-1] for e in vanilla] + +with open(os.path.join(REPO_ROOT, "results/dfa_residual_penalty/dfa_pen_lam0.01_s42.json")) as f: + pen = json.load(f) +ep_pen = [e["epoch"] for e in pen["log"]] +hL_pen = [e["h_L_norm"] for e in pen["log"]] +g_pen = [e["g_2_norm"] for e in pen["log"]] + +# Panel B: cosine + rho across vanilla / penalized / fresh-B / penalty lam=1e-4 +# Read from existing results +conditions = ["vanilla\nDFA\n(early)", "penalized\n$\\lambda{=}10^{-4}$", "penalized\n$\\lambda{=}10^{-2}$", "fresh-$B$\nnull", "BP grad\n(positive)"] +deep_cos = [-0.008, -0.022, +0.155, +0.002, +1.000] +deep_rho = [-0.003, -0.004, +0.080, +0.006, +0.997] +cos_err = [0.013, 0.0, 0.025, 0.022, 0.0] +rho_err = [0.005, 0.0, 0.011, 0.0, 0.0] + +# Panel C: 2x2 capacity-cost control +methods = ["BP", "DFA"] +no_pen = [0.609, 0.308] +with_pen = [0.530, 0.363] +shallow = 0.349 + +fig, axes = plt.subplots(1, 3, figsize=(13, 3.5)) + +# Panel A: trajectory +ax = axes[0] +ax.plot(ep_vanilla, hL_vanilla, label="vanilla DFA $\\|h_L\\|$", color="C3", lw=1.5, marker="o", markersize=3) +ax.plot(ep_pen, hL_pen, label="penalized DFA $\\|h_L\\|$ ($\\lambda{=}10^{-2}$)", color="C2", lw=1.5, marker="s", markersize=3) +ax.set_yscale("log") +ax.set_xlabel("epoch", fontsize=10) +ax.set_ylabel("$\\|h_L\\|$ (log)", fontsize=10) +ax.set_title("(a) penalty contains residual stream\n(4 OOM rescue)", fontsize=10) +ax.legend(loc="lower right", fontsize=8) +ax.grid(True, alpha=0.3, which="both") +ax2 = ax.twinx() +ax2.plot(ep_vanilla, g_vanilla, label="vanilla $\\|g\\|$", color="C3", lw=1, ls=":", marker="^", markersize=3) +ax2.plot(ep_pen, g_pen, label="penalized $\\|g\\|$", color="C2", lw=1, ls=":", marker="v", markersize=3) +ax2.axhline(1e-7, color="black", ls="--", lw=0.8, label="$10^{-7}$ floor") +ax2.set_yscale("log") +ax2.set_ylabel("$\\|g_L\\|$ (log)", fontsize=9, color="gray") +ax2.tick_params(axis="y", labelcolor="gray") + +# Panel B: cosine + rho +ax = axes[1] +xpos = np.arange(len(conditions)) +w = 0.35 +b1 = ax.bar(xpos - w/2, deep_cos, w, yerr=cos_err, label="deep cos", color="#4682b4", capsize=3) +b2 = ax.bar(xpos + w/2, deep_rho, w, yerr=rho_err, label="deep $\\rho$", color="#7da76f", capsize=3) +ax.axhline(0, color="black", lw=0.5) +ax.set_xticks(xpos) +ax.set_xticklabels(conditions, fontsize=8) +ax.set_ylabel("deep-layer alignment", fontsize=10) +ax.set_title("(b) two metrics agree across conditions\n(measurement vs random feedback)", fontsize=10) +ax.legend(loc="upper left", fontsize=8) +ax.grid(True, axis="y", alpha=0.3) +ax.set_ylim(-0.1, 1.1) + +# Panel C: 2x2 capacity-cost +ax = axes[2] +xpos = np.arange(len(methods)) +w = 0.35 +ax.bar(xpos - w/2, no_pen, w, label="no penalty", color="#4682b4") +ax.bar(xpos + w/2, with_pen, w, label="with penalty $\\lambda{=}10^{-2}$", color="#cc4444") +ax.axhline(shallow, color="black", ls="--", lw=1, label=f"frozen baseline {shallow}") +ax.set_xticks(xpos) +ax.set_xticklabels(methods, fontsize=10) +ax.set_ylabel("test accuracy", fontsize=10) +ax.set_title("(c) BP+penalty 2$\\times$2 control\n(BP-pen-cost $-8$pp; gap $17$pp $=$ credit quality)", fontsize=10) +ax.legend(loc="upper right", fontsize=8) +ax.grid(True, axis="y", alpha=0.3) +ax.set_ylim(0, 0.7) + +plt.tight_layout() +out = os.path.join(REPO_ROOT, "paper/figures/fig4_penalty_rescue.pdf") +plt.savefig(out, bbox_inches="tight", dpi=200) +print(f"Saved {out}") diff --git a/paper/figures/render_fig5_cross_arch.py b/paper/figures/render_fig5_cross_arch.py new file mode 100644 index 0000000..e163956 --- /dev/null +++ b/paper/figures/render_fig5_cross_arch.py @@ -0,0 +1,53 @@ +"""Render Figure 5: cross-architecture verdict matrix.""" +import os +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np + +REPO_ROOT = "/home/yurenh2/fa" + +# Verdict matrix: arch x diagnostic +# 0 = ok (BP), 1 = ok-non-LN-arch, 2 = walk-back +# Columns: (a) per-block growth, (b) ||g_L|| floor, (c) drift stability, (d) frozen baseline +# Rows: ResMLP-d256, ResMLP-d512, ViT-Mini, StudentNet (no LN), CNN (BN, no LN) +arches = ["ResMLP $d{=}256$\n(terminal LN)", "ResMLP $d{=}512$\n(terminal LN)", "ViT-Mini\n(cls + LN)", "StudentNet\n(no terminal LN)", "CNN BatchNorm\n(no terminal LN)"] +diags = ["(a) scale", "(b) ${\\|g\\|}$ floor", "(c) drift", "(d) frozen"] + +# DFA verdicts on each +# 1 = fires (walked back); 0 = passes +dfa = np.array([ + [1, 1, 0, 1], # ResMLP d256: (a) fires, (b) fires, (c) noise sub-mode, (d) fires + [1, 1, 0, 1], # ResMLP d512: same pattern + [1, 1, 0, 1], # ViT-Mini: same pattern + [1, 0, 0, 0], # StudentNet: only (a) fires; (b) NEVER + [1, 0, 0, 0], # CNN BN: only (a) fires; (b) NEVER (the killer (b)-is-LN-specific finding) +]) + +bp = np.zeros_like(dfa) # BP: passes everywhere + +fig, axes = plt.subplots(1, 2, figsize=(11, 3.5)) + +for ax, mat, title in [(axes[0], bp, "BP-trained: protocol passes"), + (axes[1], dfa, "DFA-trained: protocol verdict by architecture")]: + cmap = matplotlib.colors.ListedColormap(["#4682b4", "#cc4444"]) + im = ax.imshow(mat, cmap=cmap, aspect="auto", vmin=0, vmax=1) + for i in range(mat.shape[0]): + for j in range(mat.shape[1]): + txt = "WB" if mat[i, j] == 1 else "$\\checkmark$" + ax.text(j, i, txt, ha="center", va="center", color="white", fontsize=11, fontweight="bold") + ax.set_xticks(range(len(diags))) + ax.set_xticklabels(diags, fontsize=9) + ax.set_yticks(range(len(arches))) + ax.set_yticklabels(arches, fontsize=9) + ax.set_title(title, fontsize=10) + +# Highlight the key finding +axes[1].text(0.5, -1.0, "Key finding: diagnostic (b) BP-grad-floor fires only on terminal-LN architectures.\n" + "Across the 5 architecture cases tested, (b) is restricted to the with-terminal-LN family.", + ha="center", fontsize=9, style="italic", transform=axes[1].transAxes) + +plt.tight_layout() +out = os.path.join(REPO_ROOT, "paper/figures/fig5_cross_arch_summary.pdf") +plt.savefig(out, bbox_inches="tight", dpi=200) +print(f"Saved {out}") |
