diff options
Diffstat (limited to 'paper/figures/render_fig3_temporal.py')
| -rw-r--r-- | paper/figures/render_fig3_temporal.py | 192 |
1 files changed, 192 insertions, 0 deletions
diff --git a/paper/figures/render_fig3_temporal.py b/paper/figures/render_fig3_temporal.py new file mode 100644 index 0000000..d8d93db --- /dev/null +++ b/paper/figures/render_fig3_temporal.py @@ -0,0 +1,192 @@ +""" +Render Figure 3: Temporal evolution of diagnostics. + +Figure 3a: ResMLP (with terminal LN) — BP, FA, DFA overlaid +Figure 3b: ViT-Mini + ResMLP-no-outLN — BP, DFA only + +Each figure: 1 row per architecture (3a has 1 row, 3b has 2 rows), +3 columns = ||h_L||, ||g_L||, test acc. +Methods as colored lines within each panel. +""" +import os, json +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np + +REPO_ROOT = "/home/yurenh2/fa" +COLORS = {"BP": "#2166ac", "FA": "#e08214", "DFA": "#b2182b"} + +plt.rcParams.update({ + "font.size": 9, + "axes.labelsize": 10, + "axes.titlesize": 10, + "legend.fontsize": 8, + "xtick.labelsize": 8, + "ytick.labelsize": 8, + "font.family": "serif", +}) + + +def extract_series(log): + epochs = [e['epoch'] for e in log] + # Handle different key names across architectures + if 'hidden_norms' in log[0]: + h_L = [e['hidden_norms'][-1] for e in log] + elif 'hidden_norms_cls' in log[0]: + h_L = [e['hidden_norms_cls'][-1] for e in log] + else: + h_L = [1.0] * len(log) + if 'bp_grad_norms_per_sample_med' in log[0]: + g_L = [e['bp_grad_norms_per_sample_med'][-1] for e in log] + elif 'bp_grad_per_sample_l2_med' in log[0]: + g_L = [e['bp_grad_per_sample_l2_med'][-1] for e in log] + else: + g_L = [1.0] * len(log) + acc = [e['acc_eval'] for e in log] + return epochs, h_L, g_L, acc + + +def add_grid(ax, log_scale=False): + ax.grid(True, which="major", color="#d0d0d0", linewidth=0.4, linestyle=":") + if log_scale: + ax.grid(True, which="minor", color="#e8e8e8", linewidth=0.3, linestyle=":") + ax.set_axisbelow(True) + + +# ─── Load data ─────────────────────────────────────────────────────────── + +# ResMLP (with terminal LN) +resmlp = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_evolution_v2/snapshot_evolution_s42.json"))) +fa_resmlp = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_evolution_v2/snapshot_fa_s42.json"))) + +# FA canonical for ResMLP +fa_resmlp_canonical = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_evolution_v2/snapshot_fa_canonical_s42.json"))) + +# ViT-Mini +vit = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_vit_v1/snapshot_vit_s42.json"))) +fa_vit = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_vit_v1/snapshot_fa_canonical_s42.json"))) + +# StudentNet (synthetic teacher-student, no terminal LN) +synth = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_synth_v1/snapshot_synth_a1.0_L4_s42.json"))) +fa_synth = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_synth_v1/snapshot_fa_canonical_s42.json"))) + + +# ═══════════════════════════════════════════════════════════════════════════ +# Figure 3a: ResMLP — BP / FA / DFA +# ═══════════════════════════════════════════════════════════════════════════ + +fig_a, axes_a = plt.subplots(1, 3, figsize=(10.5, 2.8)) +fig_a.subplots_adjust(wspace=0.35, left=0.07, right=0.97, bottom=0.18, top=0.85) +# No suptitle — user will write caption + +data_resmlp = { + "BP": extract_series(resmlp['bp_log']), + "DFA": extract_series(resmlp['dfa_log']), + "FA": extract_series(fa_resmlp_canonical['fa_log']), +} + +# Column 0: ||h_L|| +ax = axes_a[0] +for method in ["BP", "FA", "DFA"]: + ep, h, g, a = data_resmlp[method] + ax.semilogy(ep, h, color=COLORS[method], linewidth=1.5, label=method) +ax.set_ylabel("$\\|h_L\\|_2$") +ax.set_xlabel("Epoch") +ax.set_title("$\\|h_L\\|$ (residual norm)") +ax.legend(loc="center right", fontsize=7) +add_grid(ax, log_scale=True) + +# Column 1: ||g_L|| +ax = axes_a[1] +for method in ["BP", "FA", "DFA"]: + ep, h, g, a = data_resmlp[method] + ax.semilogy(ep, g, color=COLORS[method], linewidth=1.5, label=method) +ax.set_ylabel("$\\|g_L\\|_2$") +ax.set_xlabel("Epoch") +ax.set_title("$\\|g_L\\|$ (BP gradient at $h_L$)") +add_grid(ax, log_scale=True) + +# Column 2: test acc +ax = axes_a[2] +for method in ["BP", "FA", "DFA"]: + ep, h, g, a = data_resmlp[method] + ax.plot(ep, a, color=COLORS[method], linewidth=1.5, label=method) +ax.set_ylabel("Test accuracy") +ax.set_xlabel("Epoch") +ax.set_title("Test accuracy") +ax.set_ylim(0, 0.7) +add_grid(ax) + +out_a = os.path.join(REPO_ROOT, "paper/figures/fig3a_temporal_resmlp.pdf") +fig_a.savefig(out_a, bbox_inches="tight", dpi=300) +fig_a.savefig(out_a.replace(".pdf", ".png"), bbox_inches="tight", dpi=200) +print(f"Saved: {out_a}") + + +# ═══════════════════════════════════════════════════════════════════════════ +# Figure 3b: ViT-Mini + ResMLP-no-outLN — BP / DFA only +# ═══════════════════════════════════════════════════════════════════════════ + +fig_b, axes_b = plt.subplots(2, 3, figsize=(10.5, 5.0)) +fig_b.subplots_adjust(wspace=0.35, hspace=0.45, left=0.07, right=0.97, bottom=0.10, top=0.90) + +arch_data = [ + ("ViT-Mini", vit, fa_vit), + ("StudentNet", synth, fa_synth), +] + +for row, (arch_name, arch_json, fa_json) in enumerate(arch_data): + data = { + "BP": extract_series(arch_json['bp_log']), + "FA": extract_series(fa_json['fa_log']), + "DFA": extract_series(arch_json['dfa_log']), + } + + # Column 0: ||h_L|| + ax = axes_b[row, 0] + for method in ["BP", "FA", "DFA"]: + ep, h, g, a = data[method] + ax.semilogy(ep, h, color=COLORS[method], linewidth=1.5, label=method) + ax.set_ylabel("$\\|h_L\\|_2$") + if row == 0: + ax.set_title("$\\|h_L\\|$ (residual norm)") + ax.legend(loc="center right", fontsize=7) + if row == 1: + ax.set_xlabel("Epoch") + # Architecture label on the left + ax.annotate(arch_name, xy=(0, 0.5), xytext=(-55, 0), + xycoords="axes fraction", textcoords="offset points", + fontsize=8, fontweight="bold", rotation=90, + ha="center", va="center") + add_grid(ax, log_scale=True) + + # Column 1: ||g_L|| + ax = axes_b[row, 1] + for method in ["BP", "FA", "DFA"]: + ep, h, g, a = data[method] + ax.semilogy(ep, g, color=COLORS[method], linewidth=1.5, label=method) + ax.set_ylabel("$\\|g_L\\|_2$") + if row == 0: + ax.set_title("$\\|g_L\\|$ (BP gradient at $h_L$)") + if row == 1: + ax.set_xlabel("Epoch") + add_grid(ax, log_scale=True) + + # Column 2: test acc + ax = axes_b[row, 2] + for method in ["BP", "FA", "DFA"]: + ep, h, g, a = data[method] + ax.plot(ep, a, color=COLORS[method], linewidth=1.5, label=method) + ax.set_ylabel("Test accuracy") + if row == 0: + ax.set_title("Test accuracy") + if row == 1: + ax.set_xlabel("Epoch") + ax.set_ylim(0, 0.85) + add_grid(ax) + +out_b = os.path.join(REPO_ROOT, "paper/figures/fig3b_temporal_crossarch.pdf") +fig_b.savefig(out_b, bbox_inches="tight", dpi=300) +fig_b.savefig(out_b.replace(".pdf", ".png"), bbox_inches="tight", dpi=200) +print(f"Saved: {out_b}") |
