summaryrefslogtreecommitdiff
path: root/experiments/plot_depth_ladder.py
blob: a5709bfb597a3d8f7cf3f6651612cf98ed960006 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""
Plot the depth-utility ladder: test accuracy vs number of trainable blocks k,
one curve per method (BP / FA / DFA), one panel per architecture.

Usage:
    python experiments/plot_depth_ladder.py
"""
import os, sys, json
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# One figure panel per entry: (results JSON path, panel title, L).
# L is the largest k plotted, so each panel covers k = 0..L.
CONFIGS = [
    (
        'results/depth_ladder/ladder_d256_L4_cifar10.json',
        'ResMLP d=256, L=4',
        4,
    ),
    (
        'results/depth_ladder/ladder_d512_L2_cifar10.json',
        'ResMLP d=512, L=2',
        2,
    ),
]

# One curve per entry: (key in the results JSON, legend label, colour, marker).
METHODS = [
    ('bp', 'BP', 'tab:green', 'o'),
    ('fa', 'FA', 'tab:orange', 's'),
    ('dfa', 'DFA', 'tab:red', '^'),
]


def agg(path, L, methods=None):
    """Aggregate per-run final accuracies into mean/std curves over k.

    Args:
        path: JSON file whose top-level ``'results'`` entry is nested as
            method key -> ``str(k)`` -> one record per run, each record
            holding a ``'final_acc'`` value.
        L: Largest k to read; k ranges over ``0..L`` inclusive.
        methods: Optional iterable of method keys to aggregate. When omitted,
            the keys listed in the module-level ``METHODS`` are used.

    Returns:
        Dict mapping each method key to a ``(ks, mu, sd)`` tuple of NumPy
        arrays: the k values, the mean ``final_acc`` over runs, and the sample
        standard deviation (``ddof=1``), which is 0.0 when only one run exists.

    Raises:
        KeyError: If ``'results'``, a method key, a k entry, or ``'final_acc'``
            is missing from the file.
    """
    # Close the file deterministically instead of leaving the handle to the
    # garbage collector; JSON text is read as UTF-8 regardless of locale.
    with open(path, encoding='utf-8') as fh:
        results = json.load(fh)['results']

    if methods is None:
        methods = [key for key, *_ in METHODS]

    out = {}
    for m in methods:
        ks, mu, sd = [], [], []
        for k in range(L + 1):
            accs = [run['final_acc'] for run in results[m][str(k)].values()]
            ks.append(k)
            mu.append(np.mean(accs))
            # Sample std is undefined for a single run; report no spread.
            sd.append(np.std(accs, ddof=1) if len(accs) > 1 else 0.0)
        out[m] = (np.array(ks), np.array(mu), np.array(sd))
    return out


def main():
    """Draw the depth-ladder figure and save it as a PNG.

    One panel is drawn per entry in CONFIGS. Each panel plots, for every
    method in METHODS, the mean accuracy returned by agg() against k with
    standard-deviation error bars, plus a dotted reference line at 0.10.
    The saved path is printed when done.
    """
    # squeeze=False always returns a 2-D grid of Axes, so a single-panel
    # layout needs no special-casing; row 0 holds every panel.
    fig, axes_grid = plt.subplots(
        1, len(CONFIGS), figsize=(11, 4.2), squeeze=False
    )

    for panel, (json_path, panel_title, depth) in zip(axes_grid[0], CONFIGS):
        curves = agg(json_path, depth)
        for key, legend_name, colour, marker in METHODS:
            xs, means, stds = curves[key]
            panel.errorbar(
                xs, means,
                yerr=stds,
                marker=marker,
                color=colour,
                label=legend_name,
                capsize=3,
                lw=2,
                ms=7,
            )

        # Fixed reference line at 0.10, labelled 'chance'. The label's x is in
        # axes-fraction units and its y in data units (y-axis transform).
        panel.axhline(0.10, ls=':', color='gray', lw=1)
        panel.text(
            0.02, 0.105, 'chance',
            color='gray',
            fontsize=8,
            transform=panel.get_yaxis_transform(),
        )

        panel.set_title(panel_title)
        panel.set_xlabel('trainable blocks $k$ (last $k$ of $L$)')
        panel.set_ylabel('CIFAR-10 test accuracy')
        panel.set_xticks(range(depth + 1))
        panel.grid(alpha=0.3)
        panel.legend(loc='center right')

    fig.suptitle(
        'Depth-utility ladder: does training deeper blocks raise accuracy?',
        y=1.02,
    )
    fig.tight_layout()

    png_path = 'results/depth_ladder/depth_ladder.png'
    fig.savefig(png_path, dpi=150, bbox_inches='tight')
    print(f"Saved -> {png_path}")


# Run the plot only when executed as a script, not when imported.
if __name__ == '__main__':
    main()