summary | refs | log | tree | commit | diff
path: root/paper/figures/render_fig_nooutln_temporal.py
diff options
context:
space:
mode:
Diffstat (limited to 'paper/figures/render_fig_nooutln_temporal.py')
-rw-r--r-- paper/figures/render_fig_nooutln_temporal.py 96
1 files changed, 96 insertions, 0 deletions
diff --git a/paper/figures/render_fig_nooutln_temporal.py b/paper/figures/render_fig_nooutln_temporal.py
new file mode 100644
index 0000000..443b5c1
--- /dev/null
+++ b/paper/figures/render_fig_nooutln_temporal.py
@@ -0,0 +1,96 @@
+"""
+Temporal evolution for ResMLP d=256 L=4 WITHOUT terminal LN.
+Separate figure (ablation control for Mode 1b).
+BP / FA / DFA overlaid, 3 columns = ||h_L||, ||g_L||, acc.
+"""
+import os, json
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import numpy as np
+
# Root of the repository checkout; every input and output path below is
# joined onto it. Set the FA_REPO_ROOT environment variable to run the script
# from a different checkout; unset (or empty) keeps the original location.
REPO_ROOT = os.environ.get("FA_REPO_ROOT") or "/home/yurenh2/fa"

# Line colour (hex) per training method, keyed by the label used in legends.
COLORS = {
    "BP": "#2166ac",
    "FA": "#e08214",
    "DFA": "#b2182b",
}
+
# Global matplotlib text styling applied to every panel of the figure.
plt.rcParams.update({
    "font.family": "serif",
    "font.size": 9,
    "axes.labelsize": 10,
    "axes.titlesize": 10,
    "legend.fontsize": 8,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
})
+
+
def extract_series(log):
    """Extract the per-epoch curves plotted by this script from one training log.

    Args:
        log: Sequence of per-epoch records (dicts). Every record must carry
            ``'epoch'`` and ``'acc_eval'``. ``'hidden_norms'`` and one of
            ``'bp_grad_norms_per_sample_med'`` /
            ``'bp_grad_per_sample_l2_med'`` are optional; whether they are
            used is decided from the first record only.

    Returns:
        A tuple ``(epochs, h_L, g_L, acc)`` of four lists with one entry per
        record. ``h_L`` and ``g_L`` hold the last element of the respective
        per-record list (presumably the entry for the final layer -- confirm
        against the code that writes the logs). When the key is absent from
        the first record, the series is filled with the constant ``1.0``.
        An empty ``log`` yields four empty lists.

    Raises:
        KeyError: If a record lacks ``'epoch'``, ``'acc_eval'``, or an
            optional key that the first record had.
    """
    if not log:
        # The key probing below needs a first record; without one there is
        # simply nothing to plot.
        return [], [], [], []

    first = log[0]
    n_records = len(log)

    epochs = [rec['epoch'] for rec in log]
    acc = [rec['acc_eval'] for rec in log]

    if 'hidden_norms' in first:
        h_L = [rec['hidden_norms'][-1] for rec in log]
    else:
        # Placeholder so the curve can still be drawn on a log axis.
        h_L = [1.0] * n_records

    # Two spellings of the gradient-norm key are accepted; the first one
    # present in the first record wins.
    grad_key = next(
        (key for key in ('bp_grad_norms_per_sample_med',
                         'bp_grad_per_sample_l2_med')
         if key in first),
        None,
    )
    if grad_key is None:
        g_L = [1.0] * n_records
    else:
        g_L = [rec[grad_key][-1] for rec in log]

    return epochs, h_L, g_L, acc
+
+
def _load_json(path):
    """Parse the JSON file at *path*, closing the file handle before returning."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


# The BP and DFA logs are read from the same snapshot file, under the keys
# 'bp_log' and 'dfa_log'.
noln = _load_json(os.path.join(REPO_ROOT, "results/snapshot_no_outln_v1/snapshot_noLN_s42.json"))

# Try canonical FA; fall back to BP/DFA only
fa_path = os.path.join(REPO_ROOT, "results/snapshot_no_outln_v1/snapshot_fa_canonical_noln_s42.json")
has_fa = os.path.exists(fa_path)
if has_fa:
    fa_noln = _load_json(fa_path)

# Method label -> (epochs, h_L, g_L, acc) as returned by extract_series.
data = {"BP": extract_series(noln['bp_log']), "DFA": extract_series(noln['dfa_log'])}
if has_fa:
    data["FA"] = extract_series(fa_noln['fa_log'])

# Order in which the methods are drawn (and listed in the legend).
methods = ["BP", "FA", "DFA"] if has_fa else ["BP", "DFA"]
+
# One row of three panels; margins and inter-panel spacing are set by hand.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10.5, 2.8))
fig.subplots_adjust(
    left=0.07,
    right=0.97,
    bottom=0.18,
    top=0.92,
    wspace=0.35,
)
+
# Column 0: ||h_L|| per method on a log-scaled y axis. This is the only panel
# that carries the legend.
ax = axes[0]
for m in methods:
    series = data[m]  # (epochs, h_L, g_L, acc)
    ax.semilogy(series[0], series[1], label=m, color=COLORS[m], linewidth=1.5)
ax.set_title("$\\|h_L\\|$ (residual norm)")
ax.set_xlabel("Epoch")
ax.set_ylabel("$\\|h_L\\|_2$")
ax.legend(loc="center right", fontsize=7)
# Dotted grid: major lines slightly darker and thicker than minor ones.
for tick_kind, grid_color, grid_width in (
    ("major", "#d0d0d0", 0.4),
    ("minor", "#e8e8e8", 0.3),
):
    ax.grid(True, which=tick_kind, color=grid_color, linewidth=grid_width, linestyle=":")
ax.set_axisbelow(True)  # keep the grid behind the curves
+
# Column 1: ||g_L|| per method on a log-scaled y axis (no legend here).
ax = axes[1]
for m in methods:
    series = data[m]  # (epochs, h_L, g_L, acc)
    ax.semilogy(series[0], series[2], label=m, color=COLORS[m], linewidth=1.5)
ax.set_title("$\\|g_L\\|$ (BP gradient at $h_L$)")
ax.set_xlabel("Epoch")
ax.set_ylabel("$\\|g_L\\|_2$")
# Dotted grid: major lines slightly darker and thicker than minor ones.
for tick_kind, grid_color, grid_width in (
    ("major", "#d0d0d0", 0.4),
    ("minor", "#e8e8e8", 0.3),
):
    ax.grid(True, which=tick_kind, color=grid_color, linewidth=grid_width, linestyle=":")
ax.set_axisbelow(True)  # keep the grid behind the curves
+
# Column 2: test accuracy per method on a linear y axis fixed to [0, 0.7].
ax = axes[2]
for m in methods:
    series = data[m]  # (epochs, h_L, g_L, acc)
    ax.plot(series[0], series[3], label=m, color=COLORS[m], linewidth=1.5)
ax.set_title("Test accuracy")
ax.set_xlabel("Epoch")
ax.set_ylabel("Test accuracy")
ax.set_ylim(0, 0.7)
# Only the major grid is drawn on this panel.
ax.grid(True, which="major", color="#d0d0d0", linewidth=0.4, linestyle=":")
ax.set_axisbelow(True)  # keep the grid behind the curves
+
# Save the figure as a PDF and as a PNG with the same base name.
out = os.path.join(REPO_ROOT, "paper/figures/fig_nooutln_temporal.pdf")
# Swap only the file extension: str.replace would also rewrite any ".pdf"
# occurring earlier in the path (for example inside REPO_ROOT).
png_out = os.path.splitext(out)[0] + ".png"
fig.savefig(out, bbox_inches="tight", dpi=300)
fig.savefig(png_out, bbox_inches="tight", dpi=200)
print(f"Saved: {out}")
if not has_fa:
    # The FA curve was skipped because its log file was not found on disk.
    print("NOTE: FA canonical data not yet available — will re-render when ready")