"""Render Figure 4: penalty rescue + capacity-cost control.""" import os import json import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np REPO_ROOT = "/home/yurenh2/fa" # Panel A: penalty rescue trajectory with open(os.path.join(REPO_ROOT, "results/snapshot_evolution_v2/snapshot_evolution_s42.json")) as f: snap = json.load(f) vanilla = snap["dfa_log"] ep_vanilla = [e["epoch"] for e in vanilla] hL_vanilla = [e["hidden_norms"][-1] for e in vanilla] g_vanilla = [e["bp_grad_norms_per_sample_med"][-1] for e in vanilla] with open(os.path.join(REPO_ROOT, "results/dfa_residual_penalty/dfa_pen_lam0.01_s42.json")) as f: pen = json.load(f) ep_pen = [e["epoch"] for e in pen["log"]] hL_pen = [e["h_L_norm"] for e in pen["log"]] g_pen = [e["g_2_norm"] for e in pen["log"]] # Panel B: cosine + rho across vanilla / penalized / fresh-B / penalty lam=1e-4 # Read from existing results conditions = ["vanilla\nDFA\n(early)", "penalized\n$\\lambda{=}10^{-4}$", "penalized\n$\\lambda{=}10^{-2}$", "fresh-$B$\nnull", "BP grad\n(positive)"] deep_cos = [-0.008, -0.022, +0.155, +0.002, +1.000] deep_rho = [-0.003, -0.004, +0.080, +0.006, +0.997] cos_err = [0.013, 0.0, 0.025, 0.022, 0.0] rho_err = [0.005, 0.0, 0.011, 0.0, 0.0] # Panel C: 2x2 capacity-cost control methods = ["BP", "DFA"] no_pen = [0.585, 0.301] with_pen = [0.530, 0.360] shallow = 0.349 fig, axes = plt.subplots(1, 3, figsize=(13, 6.0)) # Panel A: trajectory ax = axes[0] ax.plot(ep_vanilla, hL_vanilla, label="vanilla DFA $\\|h_L\\|$", color="C3", lw=1.5, marker="o", markersize=3) ax.plot(ep_pen, hL_pen, label="penalized DFA $\\|h_L\\|$ ($\\lambda{=}10^{-2}$)", color="C2", lw=1.5, marker="s", markersize=3) ax.set_yscale("log") ax.set_xlabel("epoch", fontsize=10) ax.set_ylabel("$\\|h_L\\|$ (log)", fontsize=10) ax.set_title("(a) penalty contains residual stream\n(4 OOM rescue)", fontsize=10) ax.legend(loc="lower right", fontsize=8) ax.grid(True, alpha=0.3, which="both") ax2 = ax.twinx() ax2.plot(ep_vanilla, g_vanilla, label="vanilla $\\|g\\|$", color="C3", lw=1, ls=":", marker="^", markersize=3) ax2.plot(ep_pen, g_pen, label="penalized $\\|g\\|$", color="C2", lw=1, ls=":", marker="v", markersize=3) ax2.axhline(1e-7, color="black", ls="--", lw=0.8, label="$10^{-7}$ floor") ax2.set_yscale("log") ax2.set_ylabel("$\\|g_L\\|$ (log)", fontsize=9, color="gray") ax2.tick_params(axis="y", labelcolor="gray") # Panel B: cosine + rho ax = axes[1] xpos = np.arange(len(conditions)) w = 0.35 b1 = ax.bar(xpos - w/2, deep_cos, w, yerr=cos_err, label="deep cos", color="#4682b4", capsize=3) b2 = ax.bar(xpos + w/2, deep_rho, w, yerr=rho_err, label="deep $\\rho$", color="#7da76f", capsize=3) ax.axhline(0, color="black", lw=0.5) ax.set_xticks(xpos) ax.set_xticklabels(conditions, fontsize=8) ax.set_ylabel("deep-layer alignment", fontsize=10) ax.set_title("(b) two metrics agree across conditions\n(measurement vs random feedback)", fontsize=10) ax.legend(loc="upper left", fontsize=8) ax.grid(True, axis="y", alpha=0.3) ax.set_ylim(-0.1, 1.1) # Panel C: 2x2 capacity-cost ax = axes[2] xpos = np.arange(len(methods)) w = 0.35 ax.bar(xpos - w/2, no_pen, w, label="no penalty", color="#4682b4") ax.bar(xpos + w/2, with_pen, w, label="with penalty $\\lambda{=}10^{-2}$", color="#cc4444") ax.axhline(shallow, color="black", ls="--", lw=1, label=f"frozen baseline {shallow}") ax.set_xticks(xpos) ax.set_xticklabels(methods, fontsize=10) ax.set_ylabel("test accuracy", fontsize=10) ax.set_title("(c) BP+penalty 2$\\times$2 control\n(BP-pen-cost $-5.5$pp; gap $17$pp $=$ credit quality)", fontsize=10) ax.legend(loc="upper right", fontsize=8) ax.grid(True, axis="y", alpha=0.3) ax.set_ylim(0, 0.7) plt.tight_layout() out = os.path.join(REPO_ROOT, "paper/figures/fig4_penalty_rescue.pdf") plt.savefig(out, bbox_inches="tight", dpi=200) print(f"Saved {out}")