summary | refs | log | tree | commit | diff
path: root/paper/figures/render_fig3b_crossarch_3row.py
blob: 05d7ad0e7c245fa2393529da7c2ba6ce0051b636 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
Figure 3b: Cross-architecture temporal evolution (3 rows × 3 columns = 9 panels).
Row 1: ViT-Mini (terminal LN)
Row 2: ResMLP (no terminal LN)
Row 3: StudentNet (no LN)
Columns: ||h_L||, ||g_L||, test acc
Methods: BP (blue), FA (orange), DFA (red)
"""
import os, json
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Root of the repository checkout; result files are read from, and figures
# written to, paths beneath it. The FA_REPO_ROOT environment variable, when set
# to a non-empty value, overrides the original hard-coded location so the
# script can run from a different checkout or machine.
REPO_ROOT = os.environ.get("FA_REPO_ROOT") or "/home/yurenh2/fa"
# Line colour (hex RGB) for each training method, keyed by the method label.
COLORS = {"BP": "#2166ac", "FA": "#e08214", "DFA": "#b2182b"}

# Global matplotlib text styling: serif typeface, base size 9 pt, axis labels
# and titles at 10 pt, tick labels and legend text at 8 pt.
plt.rcParams.update({
    "font.family": "serif",
    "font.size": 9,
    "axes.titlesize": 10,
    "axes.labelsize": 10,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
    "legend.fontsize": 8,
})


def _last_entry_series(log, candidate_keys):
    """Return ``entry[key][-1]`` for every entry, using the first usable key.

    The key is chosen by checking the first entry only; every later entry is
    then indexed with that same key. If none of ``candidate_keys`` is present
    in the first entry, a constant series of 1.0 is returned.
    """
    first = log[0]
    for key in candidate_keys:
        if key in first:
            return [entry[key][-1] for entry in log]
    return [1.0] * len(log)


def extract_series(log):
    """Extract the per-epoch curves plotted in this figure from a training log.

    Args:
        log: Sequence of per-epoch dict records. Each record must contain
            ``'epoch'`` and ``'acc_eval'``. The norm series are taken from the
            last element of ``'hidden_norms'`` (or ``'hidden_norms_cls'``) and
            of ``'bp_grad_norms_per_sample_med'`` (or
            ``'bp_grad_per_sample_l2_med'``) -- presumably the final layer;
            confirm against the code that writes the logs.

    Returns:
        A tuple ``(epochs, h_L, g_L, acc)`` of four lists with one value per
        record. A norm series whose keys are all absent from the first record
        is filled with 1.0. An empty ``log`` yields four empty lists instead
        of raising ``IndexError``.

    Raises:
        KeyError: If a record lacks ``'epoch'``, ``'acc_eval'``, or the norm
            key that was selected from the first record.
    """
    # Guard first: the key detection below inspects log[0].
    if not log:
        return [], [], [], []

    epochs = [entry['epoch'] for entry in log]
    h_L = _last_entry_series(log, ('hidden_norms', 'hidden_norms_cls'))
    g_L = _last_entry_series(
        log, ('bp_grad_norms_per_sample_med', 'bp_grad_per_sample_l2_med')
    )
    acc = [entry['acc_eval'] for entry in log]
    return epochs, h_L, g_L, acc


def add_grid(ax, log_scale=False):
    """Draw a faint dotted grid on ``ax`` and keep it beneath the plotted data.

    Args:
        ax: Axes-like object providing ``grid`` and ``set_axisbelow``.
        log_scale: When true, a lighter and thinner minor-tick grid is drawn
            in addition to the major-tick grid.
    """
    # (tick level, colour, line width) for each grid layer, in drawing order.
    layers = [("major", "#d0d0d0", 0.4)]
    if log_scale:
        layers.append(("minor", "#e8e8e8", 0.3))

    for level, shade, width in layers:
        ax.grid(True, which=level, color=shade, linewidth=width, linestyle=":")
    ax.set_axisbelow(True)


# Load data
def _load_json(rel_path):
    """Parse the JSON file at ``rel_path``, taken relative to ``REPO_ROOT``.

    The file is opened in a ``with`` block so its handle is closed as soon as
    parsing finishes, rather than being left open until garbage collection.
    """
    with open(os.path.join(REPO_ROOT, rel_path), encoding="utf-8") as fh:
        return json.load(fh)


# Two snapshot files per architecture: a main one and a separate
# "fa_canonical" one.
vit = _load_json("results/snapshot_vit_v1/snapshot_vit_s42.json")
fa_vit = _load_json("results/snapshot_vit_v1/snapshot_fa_canonical_s42.json")

noln = _load_json("results/snapshot_no_outln_v1/snapshot_noLN_s42.json")
fa_noln = _load_json("results/snapshot_no_outln_v1/snapshot_fa_canonical_noln_s42.json")

synth = _load_json("results/snapshot_synth_v1/snapshot_synth_a1.0_L4_s42.json")
fa_synth = _load_json("results/snapshot_synth_v1/snapshot_fa_canonical_s42.json")

# One (row label, main snapshot, FA snapshot) triple per figure row, top to
# bottom.
arch_data = [
    ("ViT-Mini", vit, fa_vit),
    ("ResMLP no-LN", noln, fa_noln),
    ("StudentNet", synth, fa_synth),
]

# 3 x 3 panel grid: one row per architecture, one column per plotted quantity.
fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(10.5, 7.2))
fig.subplots_adjust(left=0.10, right=0.97, bottom=0.07, top=0.93,
                    wspace=0.35, hspace=0.40)

for row_idx, (arch_label, main_json, fa_only_json) in enumerate(arch_data):
    # BP and DFA logs come from the architecture's main snapshot; the FA log
    # comes from its separate snapshot. Insertion order fixes drawing order.
    curves = {
        "BP": extract_series(main_json['bp_log']),
        "FA": extract_series(fa_only_json['fa_log']),
        "DFA": extract_series(main_json['dfa_log']),
    }
    norm_ax, grad_ax, acc_ax = axes[row_idx]
    is_top_row = row_idx == 0
    is_bottom_row = row_idx == 2

    # Draw every method's three curves in one pass: BP, then FA, then DFA.
    for method, (epochs, h_norm, g_norm, accuracy) in curves.items():
        line_style = {"color": COLORS[method], "linewidth": 1.5}
        # Only the first column carries labels, since it hosts the legend.
        norm_ax.semilogy(epochs, h_norm, label=method, **line_style)
        grad_ax.semilogy(epochs, g_norm, **line_style)
        acc_ax.plot(epochs, accuracy, **line_style)

    # Column 0: ||h_L||
    norm_ax.set_ylabel("$\\|h_L\\|_2$")
    if is_top_row:
        norm_ax.set_title("$\\|h_L\\|$  (residual norm)")
        norm_ax.legend(loc="center right", fontsize=7)
    if is_bottom_row:
        norm_ax.set_xlabel("Epoch")
    add_grid(norm_ax, log_scale=True)

    # Rotated architecture name, placed to the left of the row's first panel.
    norm_ax.annotate(arch_label, xy=(0, 0.5), xytext=(-55, 0),
                     xycoords="axes fraction", textcoords="offset points",
                     fontsize=9, fontweight="bold", rotation=90,
                     ha="center", va="center")

    # Column 1: ||g_L|| -- the same fixed y range in every row so the rows
    # can be compared directly.
    grad_ax.set_ylabel("$\\|g_L\\|_2$")
    grad_ax.set_ylim(1e-12, 5e-2)
    if is_top_row:
        grad_ax.set_title("$\\|g_L\\|$  (BP gradient at $h_L$)")
    if is_bottom_row:
        grad_ax.set_xlabel("Epoch")
    add_grid(grad_ax, log_scale=True)

    # Column 2: test accuracy on a linear axis.
    acc_ax.set_ylabel("Test accuracy")
    if is_top_row:
        acc_ax.set_title("Test accuracy")
    if is_bottom_row:
        acc_ax.set_xlabel("Epoch")
    add_grid(acc_ax)

# Save the figure under paper/figures/ as a PDF (dpi=300) and as a PNG with
# the same base name (dpi=200).
out = os.path.join(REPO_ROOT, "paper/figures/fig3b_crossarch_3row.pdf")
# Swap only the file extension: str.replace(".pdf", ".png") on the full path
# would also rewrite any ".pdf" occurring earlier in it (e.g. in REPO_ROOT).
png_out = os.path.splitext(out)[0] + ".png"
fig.savefig(out, bbox_inches="tight", dpi=300)
fig.savefig(png_out, bbox_inches="tight", dpi=200)
print(f"Saved: {out}")