1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
|
"""
Temporal trajectory of the protocol's (a) and (b) diagnostics under
DFA + residual-branch penalty, compared with vanilla DFA.

Uses the existing per-epoch logs:
  - vanilla DFA: results/snapshot_evolution_v2/snapshot_evolution_s42.json
  - penalized DFA (λ sweep): results/dfa_residual_penalty/dfa_pen_lam{0.001,0.01,0.1}_s42.json
    (the λ=0.1 run was killed mid-run, so only λ=0.001 and λ=0.01 are plotted)

Output: results/protocol_audit/figure_penalty_lambda_sweep.png

Run:
    python -m protocol.examples.temporal_penalty_vs_vanilla
"""
import os
import sys
import json
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# This script is run as protocol.examples.<module>, i.e. it lives at
# <repo>/protocol/examples/<file>.py; the repository root is therefore two
# directory levels above the directory containing this file.
REPO_ROOT = os.path.normpath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
)
def load_vanilla_log(path=None):
    """Load the per-epoch log of the vanilla (unpenalized) DFA run.

    Args:
        path: JSON file to read. Defaults to the seed-42 snapshot-evolution
            log ``results/snapshot_evolution_v2/snapshot_evolution_s42.json``
            under ``REPO_ROOT``.

    Returns:
        A list with one dict per entry of the file's ``"dfa_log"`` list, each
        holding:
          - ``epoch``: the entry's ``epoch``;
          - ``h_L``: the last element of ``hidden_norms``;
          - ``g_2``: element 2 of ``bp_grad_norms_per_sample_med``;
          - ``acc``: ``acc_eval``.

    Raises:
        FileNotFoundError: if the log does not exist. The vanilla baseline is
            always plotted, so unlike the penalty logs it is not optional.
        KeyError / IndexError: if an entry lacks the expected fields.
    """
    if path is None:
        path = os.path.join(
            REPO_ROOT, "results/snapshot_evolution_v2/snapshot_evolution_s42.json"
        )
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    return [
        {
            "epoch": entry["epoch"],
            "h_L": entry["hidden_norms"][-1],
            "g_2": entry["bp_grad_norms_per_sample_med"][2],
            "acc": entry["acc_eval"],
        }
        for entry in data["dfa_log"]
    ]
def load_penalty_log(lam, path=None):
    """Load the per-epoch log of a DFA + residual-penalty run, if it exists.

    Args:
        lam: penalty strength exactly as it appears in the file name
            (e.g. ``"0.001"``); it is interpolated verbatim via an f-string.
        path: JSON file to read. Defaults to
            ``results/dfa_residual_penalty/dfa_pen_lam{lam}_s42.json`` under
            ``REPO_ROOT``.

    Returns:
        The list stored under the file's ``"log"`` key, or ``None`` when the
        file does not exist (penalty runs are optional in the figure).

    Raises:
        KeyError: if the file has no ``"log"`` key.
        json.JSONDecodeError: if the file is not valid JSON (e.g. truncated).
    """
    if path is None:
        path = os.path.join(
            REPO_ROOT, f"results/dfa_residual_penalty/dfa_pen_lam{lam}_s42.json"
        )
    # EAFP: a separate os.path.exists() check races with the open() below.
    try:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return None
    return data["log"]
def _plot_series(ax, series, metric):
    """Plot ``metric`` against epoch on ``ax`` for every run in ``series``.

    Each run is a dict with ``label``, ``log`` (list of per-epoch dicts),
    ``keys`` (maps "h" / "g" / "acc" to that log's field names), and the
    line-style fields ``color``, ``marker``, ``ls`` and ``lw``.
    """
    for run in series:
        log = run["log"]
        field = run["keys"][metric]
        ax.plot(
            [e["epoch"] for e in log],
            [e[field] for e in log],
            label=run["label"],
            color=run["color"],
            lw=run["lw"],
            marker=run["marker"],
            markersize=4,
            linestyle=run["ls"],
        )


def main():
    """Build the three-panel penalty-vs-vanilla figure and save it as a PNG.

    Panels: (a) ‖h_L‖ per epoch and (b) ‖g_2‖ per epoch, both log-scale,
    and (c) test accuracy with reference lines. The vanilla DFA log is
    required; penalty runs whose logs are missing or empty are skipped with
    a warning on stderr. The figure is written to
    ``results/protocol_audit/figure_penalty_lambda_sweep.png`` under
    ``REPO_ROOT`` (the directory is created if needed).
    """
    # The two loaders expose differently-named fields, so each run carries
    # its own mapping from panel metric to log key.
    series = [
        {
            "label": "vanilla DFA",
            "log": load_vanilla_log(),
            "keys": {"h": "h_L", "g": "g_2", "acc": "acc"},
            "color": "C3",
            "marker": None,
            "ls": "-",
            "lw": 2,
        },
    ]

    # (lam as it appears in the file name, legend label, color, marker).
    # λ=0.1 is deliberately omitted: that run was killed mid-run.
    penalty_runs = [
        ("0.001", "DFA + λ=1e-3", "C2", "o"),
        ("0.01", "DFA + λ=1e-2", "C0", "s"),
    ]
    penalty_keys = {"h": "h_L_norm", "g": "g_2_norm", "acc": "acc_eval"}
    for lam, label, color, marker in penalty_runs:
        log = load_penalty_log(lam)
        if not log:
            # Warn rather than silently dropping the run from the figure.
            print(
                f"warning: penalty log for lam={lam} is missing or empty; skipping",
                file=sys.stderr,
            )
            continue
        series.append(
            {
                "label": label,
                "log": log,
                "keys": penalty_keys,
                "color": color,
                "marker": marker,
                "ls": "-",
                "lw": 2,
            }
        )

    fig, (ax_h, ax_g, ax_acc) = plt.subplots(
        1, 3, figsize=(13, 4.0), gridspec_kw={"wspace": 0.35}
    )

    # --- (a) ‖h_L‖ ---
    _plot_series(ax_h, series, "h")
    # Reference level labelled "BP ~200" (presumably a BP-trained model's
    # ‖h_L‖ — confirm against the source run).
    ax_h.axhline(200, color="gray", linestyle=":", lw=1, label="BP ~200")
    ax_h.set_yscale("log")
    ax_h.set_xlabel("epoch", fontsize=10)
    ax_h.set_ylabel(r"$\|h_L\|_2$ (log)", fontsize=10)
    ax_h.set_title("(a) residual stream", fontsize=11)
    ax_h.legend(loc="lower right", fontsize=8)
    ax_h.grid(True, which="both", alpha=0.3)

    # --- (b) ‖g_2‖ ---
    _plot_series(ax_g, series, "g")
    ax_g.axhline(1e-7, color="k", linestyle="--", lw=1.2, label="floor 1e-7")
    ax_g.set_yscale("log")
    ax_g.set_xlabel("epoch", fontsize=10)
    ax_g.set_ylabel(r"$\|g_2\|_2$ (log)", fontsize=10)
    ax_g.set_title("(b) BP grad floor", fontsize=11)
    ax_g.legend(loc="lower left", fontsize=8)
    ax_g.grid(True, which="both", alpha=0.3)

    # --- (c) accuracy ---
    _plot_series(ax_acc, series, "acc")
    ax_acc.axhline(0.349, color="k", linestyle="--", lw=1.0, label="DFA-shallow 0.349")
    # NOTE(review): 0.349 + 2pp would be 0.369, not 0.371 — confirm which
    # baseline this threshold is measured from.
    ax_acc.axhline(0.371, color="purple", linestyle=":", lw=1.0, label="2pp threshold 0.371")
    ax_acc.set_xlabel("epoch", fontsize=10)
    ax_acc.set_ylabel("test acc", fontsize=10)
    ax_acc.set_title("acc + frozen baseline + (d) threshold", fontsize=11)
    ax_acc.legend(loc="lower right", fontsize=7.5)
    ax_acc.grid(True, alpha=0.3)
    ax_acc.set_ylim(0.05, 0.55)

    fig.suptitle(
        "Penalty rescue: (a) and (b) cleanly fix; (d) verdict depends on λ choice",
        fontsize=11,
        y=1.05,
    )
    fig.tight_layout()

    out_path = os.path.join(
        REPO_ROOT, "results/protocol_audit/figure_penalty_lambda_sweep.png"
    )
    # savefig does not create missing parent directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, dpi=140, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {out_path}")
# Script entry point (also reached via `python -m protocol.examples.temporal_penalty_vs_vanilla`).
if __name__ == "__main__":
    main()
|