1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
|
"""
Plot the cross-architecture temporal validation result as a single figure
suitable for §3 of the paper: three columns (one per architecture) and three
rows (‖h_L‖, ‖g_L‖, test accuracy). BP and DFA trajectories are overlaid,
and the 1e-7 gradient-norm floor is drawn as a horizontal line on the ‖g_L‖ row.
Data source: per-epoch snapshot logs already saved in
results/snapshot_evolution_v2/snapshot_evolution_s{seed}.json (ResMLP)
results/snapshot_vit_v1/snapshot_vit_s{seed}.json (ViT-Mini)
results/snapshot_no_outln_v1/snapshot_noLN_s{seed}.json (StudentNet)
This script does not use a GPU and runs in under 5 seconds.
Run:
python -m protocol.examples.plot_temporal_cross_arch --seed 42
"""
import os
import sys
import json
import argparse
import matplotlib
matplotlib.use("Agg") # no display needed
import matplotlib.pyplot as plt
# Repository root: two levels above the directory containing this file
# (the script is run as `python -m protocol.examples.plot_temporal_cross_arch`,
# i.e. it lives at <repo>/protocol/examples/).
REPO_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
# Put the repo root first on the import path — presumably so repo-local
# packages resolve when this file is executed directly.
sys.path.insert(0, REPO_ROOT)
def load_snapshot(arch, seed):
    """Load the per-epoch snapshot JSON for one architecture and seed.

    Args:
        arch: one of ``"resmlp"``, ``"vit"`` or ``"no_outln"``.
        seed: integer seed embedded in the snapshot filename.

    Returns:
        ``(data, h_key, g_key)``, where *data* is the parsed JSON and the two
        keys name the per-epoch hidden-norm and BP-gradient-norm fields used
        for that architecture; or ``None`` if the snapshot file does not exist.

    Raises:
        ValueError: if *arch* is not one of the known architectures. (An
            unknown value used to fall through silently to the ``no_outln``
            files, which could plot the wrong data under the wrong label.)
    """
    # arch -> (results subdirectory, filename template, hidden-norm key, grad-norm key)
    specs = {
        "resmlp": (
            "snapshot_evolution_v2",
            "snapshot_evolution_s{seed}.json",
            "hidden_norms",
            "bp_grad_norms_per_sample_med",
        ),
        "vit": (
            "snapshot_vit_v1",
            "snapshot_vit_s{seed}.json",
            "hidden_norms_cls",
            "bp_grad_per_sample_l2_med",
        ),
        "no_outln": (
            "snapshot_no_outln_v1",
            "snapshot_noLN_s{seed}.json",
            "hidden_norms",
            "bp_grad_per_sample_l2_med",
        ),
    }
    try:
        subdir, template, h_key, g_key = specs[arch]
    except KeyError:
        raise ValueError(
            f"unknown arch {arch!r}; expected one of {sorted(specs)}"
        ) from None

    filename = template.format(seed=seed)
    path = os.path.join(REPO_ROOT, f"results/{subdir}/{filename}")
    if not os.path.exists(path):
        return None
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(path, encoding="utf-8") as f:
        d = json.load(f)
    return d, h_key, g_key
def trajectory(log, h_key, g_key):
    """Split a per-epoch log into parallel series for plotting.

    Each record in *log* is indexed with ``"epoch"``, ``"acc_eval"``,
    *h_key* and *g_key*; for the latter two only the last element of the
    stored sequence is kept (plotted as the final-layer value ‖h_L‖ / ‖g_L‖).

    Returns:
        ``(epochs, h_L, g_L, acc)`` — four lists with one entry per record.
    """
    def column(pick):
        # One pass over the log per series, in the same order as before.
        return [pick(record) for record in log]

    epochs = column(lambda record: record["epoch"])
    h_last = column(lambda record: record[h_key][-1])
    g_last = column(lambda record: record[g_key][-1])
    accuracy = column(lambda record: record["acc_eval"])
    return epochs, h_last, g_last, accuracy
def _plot_column(axes, col, label, bp, dfa):
    """Draw one architecture's three panels into column *col* of *axes*.

    *bp* and *dfa* are ``(epochs, h_L, g_L, acc)`` tuples as returned by
    ``trajectory``. Y-axis labels and legends are drawn only in column 0 so
    they are not repeated across each row.
    """
    bp_ep, bp_h, bp_g, bp_a = bp
    dfa_ep, dfa_h, dfa_g, dfa_a = dfa
    first_col = col == 0

    # Row 0: ||h_L|| on a log scale.
    ax = axes[0, col]
    ax.plot(bp_ep, bp_h, label="BP", color="C0", lw=2)
    ax.plot(dfa_ep, dfa_h, label="DFA", color="C3", lw=2)
    ax.set_yscale("log")
    ax.set_title(label, fontsize=11)
    if first_col:
        ax.set_ylabel(r"$\|h_L\|_2$ (log)", fontsize=10)
        ax.legend(loc="upper left", fontsize=8)
    ax.grid(True, which="both", alpha=0.3)

    # Row 1: ||g_L|| on a log scale, with the 1e-7 floor as a reference line.
    ax = axes[1, col]
    ax.plot(bp_ep, bp_g, label="BP", color="C0", lw=2)
    ax.plot(dfa_ep, dfa_g, label="DFA", color="C3", lw=2)
    ax.axhline(1e-7, color="black", linestyle="--", lw=1, label=r"floor $10^{-7}$")
    ax.set_yscale("log")
    if first_col:
        ax.set_ylabel(r"$\|g_L\|_2$ (log)", fontsize=10)
        ax.legend(loc="upper right", fontsize=8)
    ax.grid(True, which="both", alpha=0.3)

    # Row 2: test accuracy. The [0, 1] y-limits assume acc_eval is a
    # fraction rather than a percentage — NOTE(review): confirm.
    ax = axes[2, col]
    ax.plot(bp_ep, bp_a, label="BP", color="C0", lw=2)
    ax.plot(dfa_ep, dfa_a, label="DFA", color="C3", lw=2)
    if first_col:
        ax.set_ylabel("test acc", fontsize=10)
        ax.legend(loc="lower right", fontsize=8)
    # x-axes are not shared (sharex=False), so each bottom panel is labelled.
    ax.set_xlabel("epoch", fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 1)


def main():
    """Render the 3x3 cross-architecture figure for ``--seed`` and save it.

    Columns whose snapshot file is missing are hidden (with a warning on
    stderr) instead of aborting the figure. The output directory is created
    if it does not exist yet, so a fresh checkout can run the script.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--seed", type=int, default=42,
                   help="seed whose snapshot logs to plot")
    args = p.parse_args()

    arches = [
        ("resmlp", "ResMLP (with terminal LN)"),
        ("vit", "ViT-Mini (cls + LN)"),
        ("no_outln", "StudentNet (no terminal LN)"),
    ]
    fig, axes = plt.subplots(3, 3, figsize=(13, 9), sharex=False)
    for col, (arch, label) in enumerate(arches):
        loaded = load_snapshot(arch, args.seed)
        if loaded is None:
            print(
                f"Warning: no snapshot for {arch} (seed {args.seed}); "
                f"hiding column {col}.",
                file=sys.stderr,
            )
            for r in range(3):
                axes[r, col].set_visible(False)
            continue
        d, h_key, g_key = loaded
        bp = trajectory(d["bp_log"], h_key, g_key)
        dfa = trajectory(d["dfa_log"], h_key, g_key)
        _plot_column(axes, col, label, bp, dfa)

    fig.suptitle(
        f"Cross-architecture temporal evolution of FA diagnostics (seed {args.seed})",
        fontsize=12, y=1.0
    )
    fig.tight_layout()
    out_path = os.path.join(REPO_ROOT, f"results/protocol_audit/figure_cross_arch_temporal_s{args.seed}.png")
    # savefig does not create missing parent directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, dpi=140, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {out_path}")
if __name__ == "__main__":
    # Render the figure only when run as a script (or via `python -m`),
    # never as a side effect of importing this module.
    main()
|