1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
|
"""
Plot the depth-utility ladder: test accuracy vs number of trainable blocks k,
one curve per method (BP / FA / DFA), one panel per architecture.
Usage:
python experiments/plot_depth_ladder.py
"""
import os, sys, json
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# One entry per figure panel: (results JSON path, panel title, number of
# blocks L). agg() reads results for every k in 0..L from the JSON file.
CONFIGS = [
    (
        'results/depth_ladder/ladder_d256_L4_cifar10.json',
        'ResMLP d=256, L=4',
        4,
    ),
    (
        'results/depth_ladder/ladder_d512_L2_cifar10.json',
        'ResMLP d=512, L=2',
        2,
    ),
]

# One entry per curve: (key under 'results' in the JSON, legend label,
# matplotlib color, marker style).
METHODS = [
    ('bp', 'BP', 'tab:green', 'o'),
    ('fa', 'FA', 'tab:orange', 's'),
    ('dfa', 'DFA', 'tab:red', '^'),
]
def agg(path, L, methods=None):
    """Aggregate per-run final accuracies into mean/std curves over k.

    Args:
        path: Path to a results JSON whose top-level ``'results'`` entry is
            indexed as ``results[method][str(k)][run]['final_acc']``.
        L: Number of blocks; curves cover k = 0, 1, ..., L.
        methods: Optional iterable of method keys to aggregate. Defaults to
            the keys (first tuple element) of the module-level ``METHODS``.

    Returns:
        dict mapping each method key to ``(ks, mean, std)``, each a 1-D numpy
        array of length L + 1. ``std`` is the sample standard deviation
        (ddof=1), or 0.0 when fewer than two runs exist for a k. ``mean`` is
        NaN when no runs exist for a k.

    Raises:
        KeyError: if the file has no entry for a requested method or k.
    """
    # Context manager so the file handle is closed deterministically.
    with open(path, encoding='utf-8') as fh:
        results = json.load(fh)['results']
    if methods is None:
        methods = [key for key, _, _, _ in METHODS]

    ks = list(range(L + 1))
    out = {}
    for m in methods:
        mu, sd = [], []
        for k in ks:
            try:
                runs = results[m][str(k)]
            except KeyError as err:
                raise KeyError(
                    f"{path}: no results for method={m!r}, k={k}") from err
            accs = [run['final_acc'] for run in runs.values()]
            # Avoid np.mean([]) (RuntimeWarning); an empty cell is simply NaN.
            mu.append(np.mean(accs) if accs else float('nan'))
            # Sample std needs at least two runs; report 0 otherwise.
            sd.append(np.std(accs, ddof=1) if len(accs) > 1 else 0.0)
        out[m] = (np.array(ks), np.array(mu), np.array(sd))
    return out
def main():
    """Render one depth-ladder panel per CONFIGS entry and save it as a PNG.

    Each panel plots mean test accuracy against the number of trainable
    blocks k for every method in METHODS. Error bars are the std returned by
    agg(). A dotted reference line marks 10-class chance accuracy (0.10).
    """
    n_panels = len(CONFIGS)
    # squeeze=False always returns a 2-D array of Axes, so a single-panel
    # figure needs no special-casing. Width scales with the panel count.
    fig, axes = plt.subplots(1, n_panels, figsize=(5.5 * n_panels, 4.2),
                             squeeze=False)
    for ax, (path, title, L) in zip(axes[0], CONFIGS):
        data = agg(path, L)
        for m, label, color, mk in METHODS:
            ks, mu, sd = data[m]
            ax.errorbar(ks, mu, yerr=sd, marker=mk, color=color, label=label,
                        capsize=3, lw=2, ms=7)
        # Chance accuracy for 10-class CIFAR-10 (uniform guessing = 1/10).
        ax.axhline(0.10, ls=':', color='gray', lw=1)
        # Blended transform: x in axes fraction, y in data units, so the
        # label sits just above the chance line at the left edge.
        ax.text(0.02, 0.105, 'chance', color='gray', fontsize=8,
                transform=ax.get_yaxis_transform())
        ax.set_xlabel('trainable blocks $k$ (last $k$ of $L$)')
        ax.set_ylabel('CIFAR-10 test accuracy')
        ax.set_title(title)
        ax.set_xticks(range(L + 1))
        ax.grid(alpha=0.3)
        ax.legend(loc='center right')
    fig.suptitle('Depth-utility ladder: does training deeper blocks raise accuracy?', y=1.02)
    fig.tight_layout()
    out = 'results/depth_ladder/depth_ladder.png'
    fig.savefig(out, dpi=150, bbox_inches='tight')
    plt.close(fig)  # release the figure's memory once it is written
    print(f"Saved -> {out}")


if __name__ == '__main__':
    main()
|