"""
Render Figure 1: audit hero figure.

This script draws three panels side by side. The architecture panel
(3arc.pdf) is NOT drawn here; per the original notes it is merged in
externally.

Panel A: Standard pair (accuracy x aggregate Γ) — all 3 methods in the green zone
Panel B: Per-layer cosine — FA and DFA form an X-cross, BP flat at 1.0
Panel C: Per-layer ||g_l|| — BP flat, FA gentle decay, DFA cliff

Writes ``paper/figures/fig1_audit_hero.pdf`` and a ``.png`` twin under
REPO_ROOT (overridable through the ``FA_REPO_ROOT`` environment variable).
"""
import os
import json
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.backends.backend_pdf import PdfPages
import numpy as np
from PIL import Image
import subprocess

# NOTE(review): json, PdfPages, Image and subprocess are not referenced
# below; kept so import-time requirements are unchanged.

# Default keeps the original checkout location; override with FA_REPO_ROOT.
REPO_ROOT = os.environ.get("FA_REPO_ROOT", "/home/yurenh2/fa")

# ─── DATA ────────────────────────────────────────────────────────────────
# Panel A: accuracy and aggregate Γ (3-seed means where available)
#   BP:  from protocol_audit (d=256 L=4, 3-seed)
#   FA:  from fa_main_audit (d=256 L=4, 3-seed)
#   DFA: from protocol_audit + paper
panel_a = {
    "BP": {"acc": 0.6147, "gamma": 1.0},
    "FA": {"acc": 0.401, "gamma": 0.250},
    "DFA": {"acc": 0.306, "gamma": 0.100},
}

# Panel B: per-layer cosine (4 blocks, l=0..3)
#   BP:  by definition cos(bp_grad, bp_grad) = 1.0
#   FA:  3-seed mean from fa_main_audit
#   DFA: from d=512 L=4 s42 final (pattern robust across d/L)
panel_b = {
    "BP": [1.0, 1.0, 1.0, 1.0],
    "FA": [0.016, 0.072, -0.085, 0.997],
    "DFA": [0.400, 0.001, -0.0004, -0.002],
}

# Panel C: per-layer ||g_l|| (5 layers, l=0..4)
# All from d=256 L=4 s42 (protocol_audit for BP/DFA, fa_main_audit for FA)
panel_c = {
    "BP": [4.40e-4, 4.71e-4, 4.79e-4, 4.53e-4, 3.70e-4],
    "FA": [1.79e-5, 1.21e-6, 8.85e-7, 8.89e-7, 8.89e-7],
    "DFA": [4.39e-7, 4.19e-9, 4.18e-9, 4.17e-9, 4.17e-9],
}

CLAMP_EPS = 1e-8  # PyTorch F.cosine_similarity default eps (not referenced below)

# Horizontal boundary of Panel A's quadrants (presumably 10-class chance
# accuracy — confirm against the dataset).
CHANCE_ACC = 0.10

METHODS = ("BP", "FA", "DFA")

# ─── STYLE ───────────────────────────────────────────────────────────────
COLORS = {"BP": "#2166ac", "FA": "#e08214", "DFA": "#b2182b"}
MARKERS = {"BP": "o", "FA": "s", "DFA": "D"}

RC_PARAMS = {
    "font.size": 9,
    "axes.labelsize": 10,
    "axes.titlesize": 10,
    "legend.fontsize": 8,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
    "font.family": "serif",
}

# ─── CAPTION TEXT (one box below each panel) ─────────────────────────────
# NOTE(review): panel_c["FA"] falls from 1.79e-5 to 8.89e-7 (~20x, about
# 1.3 decades); confirm the "2 orders of magnitude" wording of caption C.
CAPTIONS = [
    "With standard reporting pair, FA and\nDFA reached non-trivial accuracy and\npositive cosine alignment in this setting",
    "Aggregated cosine lies: shallow layers\nof FA and deep layers of DFA are not\nlearning or aligned well",
    "Reference also fails: DFA collapses to\nnumerical noise at depth, FA decays\n2 orders of magnitude across layers",
]


def _draw_panel_a(ax):
    """Scatter test accuracy (y) against aggregate Γ (x) over a quadrant map."""
    ax.set_title("(A) Standard reporting pair", fontsize=9,
                 fontweight="bold", loc="left")

    x_lim = (-0.22, 1.12)
    y_lim = (-0.02, 0.72)

    # Four quadrants split at x=0 (cos=0) and y=CHANCE_ACC.
    # Upper-right (green): cos > 0 AND acc > chance
    ax.fill_between([0, x_lim[1]], CHANCE_ACC, y_lim[1],
                    color="#c8e6c9", alpha=0.5, zorder=0)
    # Lower-left (red): cos < 0 AND acc < chance
    ax.fill_between([x_lim[0], 0], y_lim[0], CHANCE_ACC,
                    color="#ffcdd2", alpha=0.5, zorder=0)
    # Upper-left (light gray): cos < 0 AND acc > chance
    ax.fill_between([x_lim[0], 0], CHANCE_ACC, y_lim[1],
                    color="#f5f5f5", alpha=0.6, zorder=0)
    # Lower-right (light gray): cos > 0 AND acc < chance
    ax.fill_between([0, x_lim[1]], y_lim[0], CHANCE_ACC,
                    color="#f5f5f5", alpha=0.6, zorder=0)

    # Quadrant boundary lines
    ax.axvline(0, color="gray", lw=0.6, ls="--", zorder=1)
    ax.axhline(CHANCE_ACC, color="gray", lw=0.6, ls="--", zorder=1)

    # Quadrant labels
    ax.text(-0.11, 0.41, "cos < 0,\nacc > chance", fontsize=6.5, color="#888",
            ha="center", va="center", style="italic", rotation=90)
    ax.text(0.55, 0.04, "cos > 0,\nacc < chance", fontsize=6.5, color="#888",
            ha="center", va="center", style="italic")
    ax.text(0.55, 0.41, '"looks like\n learning"', fontsize=7, color="#388e3c",
            ha="center", va="center", fontweight="bold")
    ax.text(-0.11, 0.04, "neither", fontsize=6.5, color="#c62828",
            ha="center", va="center", style="italic")

    # Points: x=gamma, y=acc
    for method in METHODS:
        d = panel_a[method]
        ax.scatter(d["gamma"], d["acc"], c=COLORS[method],
                   marker=MARKERS[method], s=70, zorder=5,
                   edgecolors="k", linewidths=0.5)

    # Point labels, offset (in points) so they do not cover the markers.
    label_offsets = {"BP": (-8, 8), "FA": (8, 5), "DFA": (8, -8)}
    for method in METHODS:
        d = panel_a[method]
        ax.annotate(method, (d["gamma"], d["acc"]),
                    xytext=label_offsets[method], textcoords="offset points",
                    fontsize=8, fontweight="bold", color=COLORS[method])

    ax.set_xlabel("Aggregate $\\Gamma$ (cosine)")
    ax.set_ylabel("Test accuracy")
    ax.set_xlim(*x_lim)
    ax.set_ylim(*y_lim)


def _draw_panel_b(ax):
    """Plot the per-block cosine between the applied update and the BP gradient."""
    ax.set_title("(B) Per-block cosine $\\cos(a_l, g_l)$", fontsize=9,
                 fontweight="bold", loc="left")
    blocks = np.arange(len(panel_b["BP"]))
    for method in METHODS:
        ax.plot(blocks, panel_b[method], color=COLORS[method],
                marker=MARKERS[method], markersize=5, linewidth=1.8,
                label=method, zorder=3)
    ax.axhline(0, color="gray", lw=0.5, ls=":", zorder=1)
    ax.set_xlabel("Block $l$")
    ax.set_ylabel("$\\cos(a_l,\\, \\nabla_{h_l} \\mathcal{L})$")
    ax.set_xticks(blocks)
    ax.set_xticklabels([f"$l={l}$" for l in blocks])
    ax.set_ylim(-0.25, 1.12)


def _draw_panel_c(ax):
    """Plot per-layer BP-gradient norms on a log y-axis."""
    ax.set_title("(C) Per-layer $\\|g_l\\|$ (BP gradient)", fontsize=9,
                 fontweight="bold", loc="left")
    layers = np.arange(len(panel_c["BP"]))
    for method in METHODS:
        ax.semilogy(layers, panel_c[method], color=COLORS[method],
                    marker=MARKERS[method], markersize=5, linewidth=1.8,
                    label=method, zorder=3)
    ax.set_xlabel("Layer $l$")
    ax.set_ylabel("$\\|\\partial\\mathcal{L}/\\partial h_l\\|_2$ (median)")
    ax.set_xticks(layers)
    ax.set_xticklabels([f"$h_{l}$" for l in layers])
    ax.set_ylim(5e-12, 5e-2)


def _add_caption_boxes(fig, axes, captions, box_h=0.13, box_gap=0.12):
    """Draw one rounded, italic caption box below each axes.

    box_h and box_gap are in figure coordinates: the box height, and the gap
    between an axes' bottom edge and the top of its box.
    """
    for ax, txt in zip(axes, captions):
        bbox = ax.get_position()
        by_top = bbox.y0 - box_gap
        by_bot = by_top - box_h
        fancy = mpatches.FancyBboxPatch(
            (bbox.x0, by_bot), bbox.x1 - bbox.x0, box_h,
            boxstyle="round,pad=0.008",
            facecolor="#f7f7f7", edgecolor="#aaaaaa", linewidth=0.7,
            transform=fig.transFigure, clip_on=False)
        # add_artist registers the patch with the figure (sets .figure),
        # unlike appending to fig.patches directly.
        fig.add_artist(fancy)
        # Text centered in the box
        fig.text((bbox.x0 + bbox.x1) / 2, by_bot + box_h / 2, txt,
                 ha="center", va="center", fontsize=9.5, style="italic",
                 transform=fig.transFigure)


def main():
    """Build the three-panel figure and save it as PDF (300 dpi) and PNG (200 dpi)."""
    plt.rcParams.update(RC_PARAMS)

    fig, axes = plt.subplots(1, 3, figsize=(10.5, 4.5))
    # Large bottom margin leaves room for the caption boxes.
    fig.subplots_adjust(wspace=0.38, left=0.06, right=0.97,
                        bottom=0.30, top=0.94)

    _draw_panel_a(axes[0])
    _draw_panel_b(axes[1])
    _draw_panel_c(axes[2])

    # ─── GRID (all panels) ───
    for ax in axes:
        ax.grid(True, which="major", color="#d0d0d0",
                linewidth=0.4, linestyle=":")
        ax.set_axisbelow(True)
    # Panel C also needs a minor grid for its log scale
    axes[2].grid(True, which="minor", color="#e8e8e8",
                 linewidth=0.3, linestyle=":")

    fig.canvas.draw()
    _add_caption_boxes(fig, axes, CAPTIONS)

    # ─── SAVE ───
    out = os.path.join(REPO_ROOT, "paper/figures/fig1_audit_hero.pdf")
    os.makedirs(os.path.dirname(out), exist_ok=True)
    fig.savefig(out, bbox_inches="tight", dpi=300)
    # Swap only the extension; str.replace would also rewrite any ".pdf"
    # appearing earlier in the path.
    out_png = os.path.splitext(out)[0] + ".png"
    fig.savefig(out_png, bbox_inches="tight", dpi=200)
    plt.close(fig)
    print(f"Saved: {out}")
    print(f"Saved: {out_png}")


if __name__ == "__main__":
    main()