summaryrefslogtreecommitdiff
path: root/paper/figures/render_fig3_temporal.py
diff options
context:
space:
mode:
Diffstat (limited to 'paper/figures/render_fig3_temporal.py')
-rw-r--r--paper/figures/render_fig3_temporal.py192
1 files changed, 192 insertions, 0 deletions
diff --git a/paper/figures/render_fig3_temporal.py b/paper/figures/render_fig3_temporal.py
new file mode 100644
index 0000000..d8d93db
--- /dev/null
+++ b/paper/figures/render_fig3_temporal.py
@@ -0,0 +1,192 @@
+"""
+Render Figure 3: Temporal evolution of diagnostics.
+
+Figure 3a: ResMLP (with terminal LN) — BP, FA, DFA overlaid
+Figure 3b: ViT-Mini + ResMLP-no-outLN — BP, DFA only
+
+Each figure: 1 row per architecture (3a has 1 row, 3b has 2 rows),
+3 columns = ||h_L||, ||g_L||, test acc.
+Methods as colored lines within each panel.
+"""
+import os, json
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import numpy as np
+
+REPO_ROOT = "/home/yurenh2/fa"
+COLORS = {"BP": "#2166ac", "FA": "#e08214", "DFA": "#b2182b"}
+
+plt.rcParams.update({
+ "font.size": 9,
+ "axes.labelsize": 10,
+ "axes.titlesize": 10,
+ "legend.fontsize": 8,
+ "xtick.labelsize": 8,
+ "ytick.labelsize": 8,
+ "font.family": "serif",
+})
+
+
+def extract_series(log):
+ epochs = [e['epoch'] for e in log]
+ # Handle different key names across architectures
+ if 'hidden_norms' in log[0]:
+ h_L = [e['hidden_norms'][-1] for e in log]
+ elif 'hidden_norms_cls' in log[0]:
+ h_L = [e['hidden_norms_cls'][-1] for e in log]
+ else:
+ h_L = [1.0] * len(log)
+ if 'bp_grad_norms_per_sample_med' in log[0]:
+ g_L = [e['bp_grad_norms_per_sample_med'][-1] for e in log]
+ elif 'bp_grad_per_sample_l2_med' in log[0]:
+ g_L = [e['bp_grad_per_sample_l2_med'][-1] for e in log]
+ else:
+ g_L = [1.0] * len(log)
+ acc = [e['acc_eval'] for e in log]
+ return epochs, h_L, g_L, acc
+
+
+def add_grid(ax, log_scale=False):
+ ax.grid(True, which="major", color="#d0d0d0", linewidth=0.4, linestyle=":")
+ if log_scale:
+ ax.grid(True, which="minor", color="#e8e8e8", linewidth=0.3, linestyle=":")
+ ax.set_axisbelow(True)
+
+
+# ─── Load data ───────────────────────────────────────────────────────────
+
+# ResMLP (with terminal LN)
+resmlp = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_evolution_v2/snapshot_evolution_s42.json")))
+fa_resmlp = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_evolution_v2/snapshot_fa_s42.json")))
+
+# FA canonical for ResMLP
+fa_resmlp_canonical = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_evolution_v2/snapshot_fa_canonical_s42.json")))
+
+# ViT-Mini
+vit = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_vit_v1/snapshot_vit_s42.json")))
+fa_vit = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_vit_v1/snapshot_fa_canonical_s42.json")))
+
+# StudentNet (synthetic teacher-student, no terminal LN)
+synth = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_synth_v1/snapshot_synth_a1.0_L4_s42.json")))
+fa_synth = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_synth_v1/snapshot_fa_canonical_s42.json")))
+
+
+# ═══════════════════════════════════════════════════════════════════════════
+# Figure 3a: ResMLP — BP / FA / DFA
+# ═══════════════════════════════════════════════════════════════════════════
+
+fig_a, axes_a = plt.subplots(1, 3, figsize=(10.5, 2.8))
+fig_a.subplots_adjust(wspace=0.35, left=0.07, right=0.97, bottom=0.18, top=0.85)
+# No suptitle — user will write caption
+
+data_resmlp = {
+ "BP": extract_series(resmlp['bp_log']),
+ "DFA": extract_series(resmlp['dfa_log']),
+ "FA": extract_series(fa_resmlp_canonical['fa_log']),
+}
+
+# Column 0: ||h_L||
+ax = axes_a[0]
+for method in ["BP", "FA", "DFA"]:
+ ep, h, g, a = data_resmlp[method]
+ ax.semilogy(ep, h, color=COLORS[method], linewidth=1.5, label=method)
+ax.set_ylabel("$\\|h_L\\|_2$")
+ax.set_xlabel("Epoch")
+ax.set_title("$\\|h_L\\|$ (residual norm)")
+ax.legend(loc="center right", fontsize=7)
+add_grid(ax, log_scale=True)
+
+# Column 1: ||g_L||
+ax = axes_a[1]
+for method in ["BP", "FA", "DFA"]:
+ ep, h, g, a = data_resmlp[method]
+ ax.semilogy(ep, g, color=COLORS[method], linewidth=1.5, label=method)
+ax.set_ylabel("$\\|g_L\\|_2$")
+ax.set_xlabel("Epoch")
+ax.set_title("$\\|g_L\\|$ (BP gradient at $h_L$)")
+add_grid(ax, log_scale=True)
+
+# Column 2: test acc
+ax = axes_a[2]
+for method in ["BP", "FA", "DFA"]:
+ ep, h, g, a = data_resmlp[method]
+ ax.plot(ep, a, color=COLORS[method], linewidth=1.5, label=method)
+ax.set_ylabel("Test accuracy")
+ax.set_xlabel("Epoch")
+ax.set_title("Test accuracy")
+ax.set_ylim(0, 0.7)
+add_grid(ax)
+
+out_a = os.path.join(REPO_ROOT, "paper/figures/fig3a_temporal_resmlp.pdf")
+fig_a.savefig(out_a, bbox_inches="tight", dpi=300)
+fig_a.savefig(out_a.replace(".pdf", ".png"), bbox_inches="tight", dpi=200)
+print(f"Saved: {out_a}")
+
+
+# ═══════════════════════════════════════════════════════════════════════════
+# Figure 3b: ViT-Mini + ResMLP-no-outLN — BP / DFA only
+# ═══════════════════════════════════════════════════════════════════════════
+
+fig_b, axes_b = plt.subplots(2, 3, figsize=(10.5, 5.0))
+fig_b.subplots_adjust(wspace=0.35, hspace=0.45, left=0.07, right=0.97, bottom=0.10, top=0.90)
+
+arch_data = [
+ ("ViT-Mini", vit, fa_vit),
+ ("StudentNet", synth, fa_synth),
+]
+
+for row, (arch_name, arch_json, fa_json) in enumerate(arch_data):
+ data = {
+ "BP": extract_series(arch_json['bp_log']),
+ "FA": extract_series(fa_json['fa_log']),
+ "DFA": extract_series(arch_json['dfa_log']),
+ }
+
+ # Column 0: ||h_L||
+ ax = axes_b[row, 0]
+ for method in ["BP", "FA", "DFA"]:
+ ep, h, g, a = data[method]
+ ax.semilogy(ep, h, color=COLORS[method], linewidth=1.5, label=method)
+ ax.set_ylabel("$\\|h_L\\|_2$")
+ if row == 0:
+ ax.set_title("$\\|h_L\\|$ (residual norm)")
+ ax.legend(loc="center right", fontsize=7)
+ if row == 1:
+ ax.set_xlabel("Epoch")
+ # Architecture label on the left
+ ax.annotate(arch_name, xy=(0, 0.5), xytext=(-55, 0),
+ xycoords="axes fraction", textcoords="offset points",
+ fontsize=8, fontweight="bold", rotation=90,
+ ha="center", va="center")
+ add_grid(ax, log_scale=True)
+
+ # Column 1: ||g_L||
+ ax = axes_b[row, 1]
+ for method in ["BP", "FA", "DFA"]:
+ ep, h, g, a = data[method]
+ ax.semilogy(ep, g, color=COLORS[method], linewidth=1.5, label=method)
+ ax.set_ylabel("$\\|g_L\\|_2$")
+ if row == 0:
+ ax.set_title("$\\|g_L\\|$ (BP gradient at $h_L$)")
+ if row == 1:
+ ax.set_xlabel("Epoch")
+ add_grid(ax, log_scale=True)
+
+ # Column 2: test acc
+ ax = axes_b[row, 2]
+ for method in ["BP", "FA", "DFA"]:
+ ep, h, g, a = data[method]
+ ax.plot(ep, a, color=COLORS[method], linewidth=1.5, label=method)
+ ax.set_ylabel("Test accuracy")
+ if row == 0:
+ ax.set_title("Test accuracy")
+ if row == 1:
+ ax.set_xlabel("Epoch")
+ ax.set_ylim(0, 0.85)
+ add_grid(ax)
+
+out_b = os.path.join(REPO_ROOT, "paper/figures/fig3b_temporal_crossarch.pdf")
+fig_b.savefig(out_b, bbox_inches="tight", dpi=300)
+fig_b.savefig(out_b.replace(".pdf", ".png"), bbox_inches="tight", dpi=200)
+print(f"Saved: {out_b}")