summaryrefslogtreecommitdiff
path: root/paper/figures/render_fig1_audit_hero.py
diff options
context:
space:
mode:
Diffstat (limited to 'paper/figures/render_fig1_audit_hero.py')
-rw-r--r--paper/figures/render_fig1_audit_hero.py210
1 files changed, 210 insertions, 0 deletions
diff --git a/paper/figures/render_fig1_audit_hero.py b/paper/figures/render_fig1_audit_hero.py
new file mode 100644
index 0000000..24dc8c8
--- /dev/null
+++ b/paper/figures/render_fig1_audit_hero.py
@@ -0,0 +1,210 @@
+"""
+Render Figure 1: Four-panel audit hero figure.
+
+Panel (arch): Architecture diagram (3arc.pdf, merged from external)
+Panel A: Standard pair (accuracy × aggregate Γ) — all 3 methods in the green zone
+Panel B: Per-layer cosine — FA and DFA form an X-cross, BP flat at 1.0
+Panel C: Per-layer ||g_l|| — BP flat, FA gentle decay, DFA cliff
+"""
+import os
+import json
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import matplotlib.patches as mpatches
+from matplotlib.backends.backend_pdf import PdfPages
+import numpy as np
+from PIL import Image
+import subprocess
+
+REPO_ROOT = "/home/yurenh2/fa"
+
+# ─── DATA ────────────────────────────────────────────────────────────────
+
+# Panel A: accuracy and aggregate Γ (3-seed means where available)
+# BP: from protocol_audit (d=256 L=4, 3-seed)
+# FA: from fa_main_audit (d=256 L=4, 3-seed)
+# DFA: from protocol_audit + paper
+
+panel_a = {
+ "BP": {"acc": 0.6147, "gamma": 1.0},
+ "FA": {"acc": 0.401, "gamma": 0.250},
+ "DFA": {"acc": 0.306, "gamma": 0.100},
+}
+
+# Panel B: per-layer cosine (4 blocks, l=0..3)
+# BP: by definition cos(bp_grad, bp_grad) = 1.0
+# FA: 3-seed mean from fa_main_audit
+# DFA: from d=512 L=4 s42 final (pattern robust across d/L)
+panel_b = {
+ "BP": [1.0, 1.0, 1.0, 1.0],
+ "FA": [0.016, 0.072, -0.085, 0.997],
+ "DFA": [0.400, 0.001, -0.0004, -0.002],
+}
+
+# Panel C: per-layer ||g_l|| (5 layers, l=0..4)
+# All from d=256 L=4 s42 (protocol_audit for BP/DFA, fa_main_audit for FA)
+panel_c = {
+ "BP": [4.40e-4, 4.71e-4, 4.79e-4, 4.53e-4, 3.70e-4],
+ "FA": [1.79e-5, 1.21e-6, 8.85e-7, 8.89e-7, 8.89e-7],
+ "DFA": [4.39e-7, 4.19e-9, 4.18e-9, 4.17e-9, 4.17e-9],
+}
+
+CLAMP_EPS = 1e-8 # PyTorch F.cosine_similarity default eps
+
+# ─── STYLE ───────────────────────────────────────────────────────────────
+
+COLORS = {"BP": "#2166ac", "FA": "#e08214", "DFA": "#b2182b"}
+MARKERS = {"BP": "o", "FA": "s", "DFA": "D"}
+plt.rcParams.update({
+ "font.size": 9,
+ "axes.labelsize": 10,
+ "axes.titlesize": 10,
+ "legend.fontsize": 8,
+ "xtick.labelsize": 8,
+ "ytick.labelsize": 8,
+ "font.family": "serif",
+})
+
+fig, axes = plt.subplots(1, 3, figsize=(10.5, 4.5))
+fig.subplots_adjust(wspace=0.38, left=0.06, right=0.97, bottom=0.30, top=0.94)
+
+# ─── PANEL A: Standard pair (x=Γ, y=acc) ────────────────────────────────
+
+ax = axes[0]
+ax.set_title("(A) Standard reporting pair", fontsize=9, fontweight="bold", loc="left")
+
+# Axes: x = aggregate Γ (cosine), y = test accuracy
+x_lim = (-0.22, 1.12)
+y_lim = (-0.02, 0.72)
+
+# Four quadrants: boundaries at x=0 (cos=0) and y=0.10 (chance)
+# Upper-right (green): cos > 0 AND acc > chance
+ax.fill_between([0, x_lim[1]], 0.10, y_lim[1], color="#c8e6c9", alpha=0.5, zorder=0)
+# Lower-left (red): cos < 0 AND acc < chance
+ax.fill_between([x_lim[0], 0], y_lim[0], 0.10, color="#ffcdd2", alpha=0.5, zorder=0)
+# Upper-left (light gray): cos < 0 AND acc > chance
+ax.fill_between([x_lim[0], 0], 0.10, y_lim[1], color="#f5f5f5", alpha=0.6, zorder=0)
+# Lower-right (light gray): cos > 0 AND acc < chance
+ax.fill_between([0, x_lim[1]], y_lim[0], 0.10, color="#f5f5f5", alpha=0.6, zorder=0)
+
+# Quadrant boundary lines
+ax.axvline(0, color="gray", lw=0.6, ls="--", zorder=1)
+ax.axhline(0.10, color="gray", lw=0.6, ls="--", zorder=1)
+
+# Quadrant labels
+ax.text(-0.11, 0.41, "cos < 0,\nacc > chance", fontsize=6.5, color="#888",
+ ha="center", va="center", style="italic", rotation=90)
+ax.text(0.55, 0.04, "cos > 0,\nacc < chance", fontsize=6.5, color="#888",
+ ha="center", va="center", style="italic")
+ax.text(0.55, 0.41, '"looks like\n learning"', fontsize=7, color="#388e3c",
+ ha="center", va="center", fontweight="bold")
+ax.text(-0.11, 0.04, "neither", fontsize=6.5, color="#c62828",
+ ha="center", va="center", style="italic")
+
+# Points: x=gamma, y=acc
+for method in ["BP", "FA", "DFA"]:
+ d = panel_a[method]
+ ax.scatter(d["gamma"], d["acc"], c=COLORS[method], marker=MARKERS[method],
+ s=70, zorder=5, edgecolors="k", linewidths=0.5)
+
+# Labels
+ax.annotate("BP", (panel_a["BP"]["gamma"], panel_a["BP"]["acc"]),
+ xytext=(-8, 8), textcoords="offset points", fontsize=8, fontweight="bold",
+ color=COLORS["BP"])
+ax.annotate("FA", (panel_a["FA"]["gamma"], panel_a["FA"]["acc"]),
+ xytext=(8, 5), textcoords="offset points", fontsize=8, fontweight="bold",
+ color=COLORS["FA"])
+ax.annotate("DFA", (panel_a["DFA"]["gamma"], panel_a["DFA"]["acc"]),
+ xytext=(8, -8), textcoords="offset points", fontsize=8, fontweight="bold",
+ color=COLORS["DFA"])
+
+ax.set_xlabel("Aggregate $\\Gamma$ (cosine)")
+ax.set_ylabel("Test accuracy")
+ax.set_xlim(*x_lim)
+ax.set_ylim(*y_lim)
+
+# ─── PANEL B: Per-layer cosine ──────────────────────────────────────────
+
+ax = axes[1]
+ax.set_title("(B) Per-block cosine $\\cos(a_l, g_l)$", fontsize=9, fontweight="bold", loc="left")
+
+blocks = np.arange(4)
+for method in ["BP", "FA", "DFA"]:
+ vals = panel_b[method]
+ ax.plot(blocks, vals, color=COLORS[method], marker=MARKERS[method],
+ markersize=5, linewidth=1.8, label=method, zorder=3)
+
+ax.axhline(0, color="gray", lw=0.5, ls=":", zorder=1)
+ax.set_xlabel("Block $l$")
+ax.set_ylabel("$\\cos(a_l,\\, \\nabla_{h_l} \\mathcal{L})$")
+ax.set_xticks(blocks)
+ax.set_xticklabels([f"$l={l}$" for l in blocks])
+ax.set_ylim(-0.25, 1.12)
+
+# ─── PANEL C: Per-layer ||g_l|| ─────────────────────────────────────────
+
+ax = axes[2]
+ax.set_title("(C) Per-layer $\\|g_l\\|$ (BP gradient)", fontsize=9, fontweight="bold", loc="left")
+
+layers = np.arange(5)
+for method in ["BP", "FA", "DFA"]:
+ vals = panel_c[method]
+ ax.semilogy(layers, vals, color=COLORS[method], marker=MARKERS[method],
+ markersize=5, linewidth=1.8, label=method, zorder=3)
+
+ax.set_xlabel("Layer $l$")
+ax.set_ylabel("$\\|\\partial\\mathcal{L}/\\partial h_l\\|_2$ (median)")
+ax.set_xticks(layers)
+ax.set_xticklabels([f"$h_{l}$" for l in layers])
+ax.set_ylim(5e-12, 5e-2)
+
+# ─── GRID (all panels) ───────────────────────────────────────────────────
+
+for ax in axes:
+ ax.grid(True, which="major", color="#d0d0d0", linewidth=0.4, linestyle=":")
+ ax.set_axisbelow(True)
+# Panel C also needs minor grid for log scale
+axes[2].grid(True, which="minor", color="#e8e8e8", linewidth=0.3, linestyle=":")
+
+# ─── CAPTION BOXES below each panel ─────────────────────────────────────
+
+captions = [
+ "With standard reporting pair, FA and\nDFA reached non-trivial accuracy and\npositive cosine alignment in this setting",
+ "Aggregated cosine lies: shallow layers\nof FA and deep layers of DFA are not\nlearning or aligned well",
+ "Reference also fails: DFA collapses to\nnumerical noise at depth, FA decays\n2 orders of magnitude across layers",
+]
+
+fig.canvas.draw()
+
+box_h = 0.13 # height of caption box in figure coords
+box_gap = 0.12 # gap between axes bottom and box top
+
+for i, (ax, txt) in enumerate(zip(axes, captions)):
+ bbox = ax.get_position()
+ bx0 = bbox.x0
+ bx1 = bbox.x1
+ by_top = bbox.y0 - box_gap
+ by_bot = by_top - box_h
+
+ # Draw rounded rectangle
+ fancy = mpatches.FancyBboxPatch(
+ (bx0, by_bot), bx1 - bx0, box_h,
+ boxstyle="round,pad=0.008",
+ facecolor="#f7f7f7", edgecolor="#aaaaaa", linewidth=0.7,
+ transform=fig.transFigure, clip_on=False)
+ fig.patches.append(fancy)
+
+ # Text centered in the box
+ fig.text((bx0 + bx1) / 2, by_bot + box_h / 2, txt,
+ ha="center", va="center", fontsize=9.5, style="italic",
+ transform=fig.transFigure)
+
+# ─── SAVE ────────────────────────────────────────────────────────────────
+
+out = os.path.join(REPO_ROOT, "paper/figures/fig1_audit_hero.pdf")
+fig.savefig(out, bbox_inches="tight", dpi=300)
+out_png = out.replace(".pdf", ".png")
+fig.savefig(out_png, bbox_inches="tight", dpi=200)
+print(f"Saved: {out}")
+print(f"Saved: {out_png}")