1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
|
"""Render Figure 4: penalty rescue + capacity-cost control."""
import os
import json
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
# Root of the repository checkout; every input and output path in this
# script is joined onto it.
REPO_ROOT = "/home/yurenh2/fa"

# Panel A: penalty rescue trajectory.
# Vanilla run: the per-epoch records are stored under the "dfa_log" key.
with open(
    os.path.join(REPO_ROOT, "results/snapshot_evolution_v2/snapshot_evolution_s42.json"),
    encoding="utf-8",  # JSON is UTF-8; do not depend on the locale default
) as f:
    snap = json.load(f)
vanilla = snap["dfa_log"]
ep_vanilla = [e["epoch"] for e in vanilla]
# Only the last element of each per-record list is plotted -- presumably the
# deepest layer; NOTE(review): confirm against the code that writes this log.
hL_vanilla = [e["hidden_norms"][-1] for e in vanilla]
g_vanilla = [e["bp_grad_norms_per_sample_med"][-1] for e in vanilla]

# Penalized run: the per-epoch records are stored under the "log" key and
# already hold scalar norms.
with open(
    os.path.join(REPO_ROOT, "results/dfa_residual_penalty/dfa_pen_lam0.01_s42.json"),
    encoding="utf-8",
) as f:
    pen = json.load(f)
ep_pen = [e["epoch"] for e in pen["log"]]
hL_pen = [e["h_L_norm"] for e in pen["log"]]
# NOTE(review): the key is "g_2_norm"; confirm it is the quantity comparable
# to the vanilla run's "bp_grad_norms_per_sample_med"[-1].
g_pen = [e["g_2_norm"] for e in pen["log"]]
# Panel B: deep-layer cosine and rho for each condition.
# One row per condition, in plotting order:
#   (tick label, cos, rho, cos error bar, rho error bar)
# The numbers are literals in this script, not loaded from disk.
# NOTE(review): an earlier comment said they come from existing results --
# confirm the source files before changing any value.
_PANEL_B_ROWS = (
    ("vanilla\nDFA\n(early)", -0.008, -0.003, 0.013, 0.005),
    ("penalized\n$\\lambda{=}10^{-4}$", -0.022, -0.004, 0.0, 0.0),
    ("penalized\n$\\lambda{=}10^{-2}$", +0.155, +0.080, 0.025, 0.011),
    ("fresh-$B$\nnull", +0.002, +0.006, 0.022, 0.0),
    ("BP grad\n(positive)", +1.000, +0.997, 0.0, 0.0),
)
# Transpose the row table into the five parallel lists the plotting code uses.
conditions, deep_cos, deep_rho, cos_err, rho_err = (
    list(column) for column in zip(*_PANEL_B_ROWS)
)

# Panel C: 2x2 capacity-cost control.
# One row per method: (name, accuracy without penalty, accuracy with penalty).
_PANEL_C_ROWS = (
    ("BP", 0.585, 0.530),
    ("DFA", 0.301, 0.360),
)
methods, no_pen, with_pen = (list(column) for column in zip(*_PANEL_C_ROWS))
# Reference accuracy drawn as a horizontal dashed line in panel C.
shallow = 0.349
# One row of three panels; axes[0] is drawn here, axes[1] and axes[2] are
# filled in further down the script.
fig, axes = plt.subplots(1, 3, figsize=(13, 6.0))

# Panel A: trajectory of the hidden-state norm (left y-axis, log scale).
ax = axes[0]
ax.plot(ep_vanilla, hL_vanilla, label="vanilla DFA $\\|h_L\\|$", color="C3", lw=1.5, marker="o", markersize=3)
ax.plot(ep_pen, hL_pen, label="penalized DFA $\\|h_L\\|$ ($\\lambda{=}10^{-2}$)", color="C2", lw=1.5, marker="s", markersize=3)
ax.set_yscale("log")
ax.set_xlabel("epoch", fontsize=10)
ax.set_ylabel("$\\|h_L\\|$ (log)", fontsize=10)
ax.set_title("(a) penalty contains residual stream\n(4 OOM rescue)", fontsize=10)
ax.grid(True, alpha=0.3, which="both")

# Gradient norms share the x-axis but get their own log-scaled right y-axis.
ax2 = ax.twinx()
ax2.plot(ep_vanilla, g_vanilla, label="vanilla $\\|g\\|$", color="C3", lw=1, ls=":", marker="^", markersize=3)
ax2.plot(ep_pen, g_pen, label="penalized $\\|g\\|$", color="C2", lw=1, ls=":", marker="v", markersize=3)
ax2.axhline(1e-7, color="black", ls="--", lw=0.8, label="$10^{-7}$ floor")
ax2.set_yscale("log")
ax2.set_ylabel("$\\|g_L\\|$ (log)", fontsize=9, color="gray")
ax2.tick_params(axis="y", labelcolor="gray")

# A legend only collects artists from the axes it is called on, so calling
# ax.legend() alone would silently drop the three labelled twin-axis artists.
# Gather handles from both axes and draw one combined legend. It is attached
# to ax2 (the later-drawn twin) so the dotted curves do not paint over it.
norm_handles, norm_labels = ax.get_legend_handles_labels()
grad_handles, grad_labels = ax2.get_legend_handles_labels()
ax2.legend(norm_handles + grad_handles, norm_labels + grad_labels, loc="lower right", fontsize=8)
# Panel B: grouped bars -- one (cos, rho) pair per condition, with error bars.
ax = axes[1]
bar_w = 0.35
centers = np.arange(len(conditions))
err_style = {"capsize": 3}

# Each pair is centred on its tick: cos is shifted half a bar width to the
# left, rho half a bar width to the right.
cos_bars = ax.bar(
    centers - bar_w / 2, deep_cos, bar_w,
    yerr=cos_err, label="deep cos", color="#4682b4", **err_style,
)
rho_bars = ax.bar(
    centers + bar_w / 2, deep_rho, bar_w,
    yerr=rho_err, label="deep $\\rho$", color="#7da76f", **err_style,
)

# Thin zero line so negative alignments are easy to tell from positive ones.
ax.axhline(0, color="black", lw=0.5)

ax.set_title(
    "(b) two metrics agree across conditions\n(measurement vs random feedback)",
    fontsize=10,
)
ax.set_ylabel("deep-layer alignment", fontsize=10)
ax.set_ylim(-0.1, 1.1)
ax.set_xticks(centers)
ax.set_xticklabels(conditions, fontsize=8)
ax.grid(True, axis="y", alpha=0.3)
ax.legend(loc="upper left", fontsize=8)
# Panel C: 2x2 capacity-cost control -- test accuracy for each method, with
# and without the penalty, against a dashed reference line.
ax = axes[2]
xpos = np.arange(len(methods))
w = 0.35
ax.bar(xpos - w/2, no_pen, w, label="no penalty", color="#4682b4")
ax.bar(xpos + w/2, with_pen, w, label="with penalty $\\lambda{=}10^{-2}$", color="#cc4444")
ax.axhline(shallow, color="black", ls="--", lw=1, label=f"frozen baseline {shallow}")
ax.set_xticks(xpos)
ax.set_xticklabels(methods, fontsize=10)
ax.set_ylabel("test accuracy", fontsize=10)

# Derive the two percentage-point figures quoted in the title from the
# plotted data so the caption cannot drift from the bars:
#   * penalty cost for BP = BP with penalty - BP without penalty
#   * gap                 = BP with penalty - DFA with penalty
# With the current values these format to "-5.5" and "17".
bp_idx = methods.index("BP")
dfa_idx = methods.index("DFA")
bp_pen_cost_pp = (with_pen[bp_idx] - no_pen[bp_idx]) * 100
pen_gap_pp = (with_pen[bp_idx] - with_pen[dfa_idx]) * 100
ax.set_title(
    "(c) BP+penalty 2$\\times$2 control\n"
    f"(BP-pen-cost ${bp_pen_cost_pp:+.1f}$pp; gap ${pen_gap_pp:.0f}$pp $=$ credit quality)",
    fontsize=10,
)
ax.legend(loc="upper right", fontsize=8)
ax.grid(True, axis="y", alpha=0.3)
ax.set_ylim(0, 0.7)
# Finalise the layout and write the figure as a PDF under the repository.
plt.tight_layout()
out = os.path.join(REPO_ROOT, "paper/figures/fig4_penalty_rescue.pdf")
# savefig does not create missing parent directories; make sure the target
# folder exists so a fresh checkout does not fail after all plotting is done.
os.makedirs(os.path.dirname(out), exist_ok=True)
plt.savefig(out, bbox_inches="tight", dpi=200)
print(f"Saved {out}")
|