diff options
Diffstat (limited to 'paper/figures/render_fig1_audit_hero.py')
| -rw-r--r-- | paper/figures/render_fig1_audit_hero.py | 210 |
1 files changed, 210 insertions, 0 deletions
diff --git a/paper/figures/render_fig1_audit_hero.py b/paper/figures/render_fig1_audit_hero.py new file mode 100644 index 0000000..24dc8c8 --- /dev/null +++ b/paper/figures/render_fig1_audit_hero.py @@ -0,0 +1,210 @@ +""" +Render Figure 1: Four-panel audit hero figure. + +Panel (arch): Architecture diagram (3arc.pdf, merged from external) +Panel A: Standard pair (accuracy × aggregate Γ) — all 3 methods in the green zone +Panel B: Per-layer cosine — FA and DFA form an X-cross, BP flat at 1.0 +Panel C: Per-layer ||g_l|| — BP flat, FA gentle decay, DFA cliff +""" +import os +import json +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from matplotlib.backends.backend_pdf import PdfPages +import numpy as np +from PIL import Image +import subprocess + +REPO_ROOT = "/home/yurenh2/fa" + +# ─── DATA ──────────────────────────────────────────────────────────────── + +# Panel A: accuracy and aggregate Γ (3-seed means where available) +# BP: from protocol_audit (d=256 L=4, 3-seed) +# FA: from fa_main_audit (d=256 L=4, 3-seed) +# DFA: from protocol_audit + paper + +panel_a = { + "BP": {"acc": 0.6147, "gamma": 1.0}, + "FA": {"acc": 0.401, "gamma": 0.250}, + "DFA": {"acc": 0.306, "gamma": 0.100}, +} + +# Panel B: per-layer cosine (4 blocks, l=0..3) +# BP: by definition cos(bp_grad, bp_grad) = 1.0 +# FA: 3-seed mean from fa_main_audit +# DFA: from d=512 L=4 s42 final (pattern robust across d/L) +panel_b = { + "BP": [1.0, 1.0, 1.0, 1.0], + "FA": [0.016, 0.072, -0.085, 0.997], + "DFA": [0.400, 0.001, -0.0004, -0.002], +} + +# Panel C: per-layer ||g_l|| (5 layers, l=0..4) +# All from d=256 L=4 s42 (protocol_audit for BP/DFA, fa_main_audit for FA) +panel_c = { + "BP": [4.40e-4, 4.71e-4, 4.79e-4, 4.53e-4, 3.70e-4], + "FA": [1.79e-5, 1.21e-6, 8.85e-7, 8.89e-7, 8.89e-7], + "DFA": [4.39e-7, 4.19e-9, 4.18e-9, 4.17e-9, 4.17e-9], +} + +CLAMP_EPS = 1e-8 # PyTorch F.cosine_similarity default eps + +# ─── STYLE ─────────────────────────────────────────────────────────────── + +COLORS = {"BP": "#2166ac", "FA": "#e08214", "DFA": "#b2182b"} +MARKERS = {"BP": "o", "FA": "s", "DFA": "D"} +plt.rcParams.update({ + "font.size": 9, + "axes.labelsize": 10, + "axes.titlesize": 10, + "legend.fontsize": 8, + "xtick.labelsize": 8, + "ytick.labelsize": 8, + "font.family": "serif", +}) + +fig, axes = plt.subplots(1, 3, figsize=(10.5, 4.5)) +fig.subplots_adjust(wspace=0.38, left=0.06, right=0.97, bottom=0.30, top=0.94) + +# ─── PANEL A: Standard pair (x=Γ, y=acc) ──────────────────────────────── + +ax = axes[0] +ax.set_title("(A) Standard reporting pair", fontsize=9, fontweight="bold", loc="left") + +# Axes: x = aggregate Γ (cosine), y = test accuracy +x_lim = (-0.22, 1.12) +y_lim = (-0.02, 0.72) + +# Four quadrants: boundaries at x=0 (cos=0) and y=0.10 (chance) +# Upper-right (green): cos > 0 AND acc > chance +ax.fill_between([0, x_lim[1]], 0.10, y_lim[1], color="#c8e6c9", alpha=0.5, zorder=0) +# Lower-left (red): cos < 0 AND acc < chance +ax.fill_between([x_lim[0], 0], y_lim[0], 0.10, color="#ffcdd2", alpha=0.5, zorder=0) +# Upper-left (light gray): cos < 0 AND acc > chance +ax.fill_between([x_lim[0], 0], 0.10, y_lim[1], color="#f5f5f5", alpha=0.6, zorder=0) +# Lower-right (light gray): cos > 0 AND acc < chance +ax.fill_between([0, x_lim[1]], y_lim[0], 0.10, color="#f5f5f5", alpha=0.6, zorder=0) + +# Quadrant boundary lines +ax.axvline(0, color="gray", lw=0.6, ls="--", zorder=1) +ax.axhline(0.10, color="gray", lw=0.6, ls="--", zorder=1) + +# Quadrant labels +ax.text(-0.11, 0.41, "cos < 0,\nacc > chance", fontsize=6.5, color="#888", + ha="center", va="center", style="italic", rotation=90) +ax.text(0.55, 0.04, "cos > 0,\nacc < chance", fontsize=6.5, color="#888", + ha="center", va="center", style="italic") +ax.text(0.55, 0.41, '"looks like\n learning"', fontsize=7, color="#388e3c", + ha="center", va="center", fontweight="bold") +ax.text(-0.11, 0.04, "neither", fontsize=6.5, color="#c62828", + ha="center", va="center", style="italic") + +# Points: x=gamma, y=acc +for method in ["BP", "FA", "DFA"]: + d = panel_a[method] + ax.scatter(d["gamma"], d["acc"], c=COLORS[method], marker=MARKERS[method], + s=70, zorder=5, edgecolors="k", linewidths=0.5) + +# Labels +ax.annotate("BP", (panel_a["BP"]["gamma"], panel_a["BP"]["acc"]), + xytext=(-8, 8), textcoords="offset points", fontsize=8, fontweight="bold", + color=COLORS["BP"]) +ax.annotate("FA", (panel_a["FA"]["gamma"], panel_a["FA"]["acc"]), + xytext=(8, 5), textcoords="offset points", fontsize=8, fontweight="bold", + color=COLORS["FA"]) +ax.annotate("DFA", (panel_a["DFA"]["gamma"], panel_a["DFA"]["acc"]), + xytext=(8, -8), textcoords="offset points", fontsize=8, fontweight="bold", + color=COLORS["DFA"]) + +ax.set_xlabel("Aggregate $\\Gamma$ (cosine)") +ax.set_ylabel("Test accuracy") +ax.set_xlim(*x_lim) +ax.set_ylim(*y_lim) + +# ─── PANEL B: Per-layer cosine ────────────────────────────────────────── + +ax = axes[1] +ax.set_title("(B) Per-block cosine $\\cos(a_l, g_l)$", fontsize=9, fontweight="bold", loc="left") + +blocks = np.arange(4) +for method in ["BP", "FA", "DFA"]: + vals = panel_b[method] + ax.plot(blocks, vals, color=COLORS[method], marker=MARKERS[method], + markersize=5, linewidth=1.8, label=method, zorder=3) + +ax.axhline(0, color="gray", lw=0.5, ls=":", zorder=1) +ax.set_xlabel("Block $l$") +ax.set_ylabel("$\\cos(a_l,\\, \\nabla_{h_l} \\mathcal{L})$") +ax.set_xticks(blocks) +ax.set_xticklabels([f"$l={l}$" for l in blocks]) +ax.set_ylim(-0.25, 1.12) + +# ─── PANEL C: Per-layer ||g_l|| ───────────────────────────────────────── + +ax = axes[2] +ax.set_title("(C) Per-layer $\\|g_l\\|$ (BP gradient)", fontsize=9, fontweight="bold", loc="left") + +layers = np.arange(5) +for method in ["BP", "FA", "DFA"]: + vals = panel_c[method] + ax.semilogy(layers, vals, color=COLORS[method], marker=MARKERS[method], + markersize=5, linewidth=1.8, label=method, zorder=3) + +ax.set_xlabel("Layer $l$") +ax.set_ylabel("$\\|\\partial\\mathcal{L}/\\partial h_l\\|_2$ (median)") +ax.set_xticks(layers) +ax.set_xticklabels([f"$h_{l}$" for l in layers]) +ax.set_ylim(5e-12, 5e-2) + +# ─── GRID (all panels) ─────────────────────────────────────────────────── + +for ax in axes: + ax.grid(True, which="major", color="#d0d0d0", linewidth=0.4, linestyle=":") + ax.set_axisbelow(True) +# Panel C also needs minor grid for log scale +axes[2].grid(True, which="minor", color="#e8e8e8", linewidth=0.3, linestyle=":") + +# ─── CAPTION BOXES below each panel ───────────────────────────────────── + +captions = [ + "With standard reporting pair, FA and\nDFA reached non-trivial accuracy and\npositive cosine alignment in this setting", + "Aggregated cosine lies: shallow layers\nof FA and deep layers of DFA are not\nlearning or aligned well", + "Reference also fails: DFA collapses to\nnumerical noise at depth, FA decays\n2 orders of magnitude across layers", +] + +fig.canvas.draw() + +box_h = 0.13 # height of caption box in figure coords +box_gap = 0.12 # gap between axes bottom and box top + +for i, (ax, txt) in enumerate(zip(axes, captions)): + bbox = ax.get_position() + bx0 = bbox.x0 + bx1 = bbox.x1 + by_top = bbox.y0 - box_gap + by_bot = by_top - box_h + + # Draw rounded rectangle + fancy = mpatches.FancyBboxPatch( + (bx0, by_bot), bx1 - bx0, box_h, + boxstyle="round,pad=0.008", + facecolor="#f7f7f7", edgecolor="#aaaaaa", linewidth=0.7, + transform=fig.transFigure, clip_on=False) + fig.patches.append(fancy) + + # Text centered in the box + fig.text((bx0 + bx1) / 2, by_bot + box_h / 2, txt, + ha="center", va="center", fontsize=9.5, style="italic", + transform=fig.transFigure) + +# ─── SAVE ──────────────────────────────────────────────────────────────── + +out = os.path.join(REPO_ROOT, "paper/figures/fig1_audit_hero.pdf") +fig.savefig(out, bbox_inches="tight", dpi=300) +out_png = out.replace(".pdf", ".png") +fig.savefig(out_png, bbox_inches="tight", dpi=200) +print(f"Saved: {out}") +print(f"Saved: {out_png}") |
