diff options
Diffstat (limited to 'paper/figures/render_fig_nooutln_temporal.py')
| -rw-r--r-- | paper/figures/render_fig_nooutln_temporal.py | 96 |
1 files changed, 96 insertions, 0 deletions
diff --git a/paper/figures/render_fig_nooutln_temporal.py b/paper/figures/render_fig_nooutln_temporal.py new file mode 100644 index 0000000..443b5c1 --- /dev/null +++ b/paper/figures/render_fig_nooutln_temporal.py @@ -0,0 +1,96 @@ +""" +Temporal evolution for ResMLP d=256 L=4 WITHOUT terminal LN. +Separate figure (ablation control for Mode 1b). +BP / FA / DFA overlaid, 3 columns = ||h_L||, ||g_L||, acc. +""" +import os, json +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np + +REPO_ROOT = "/home/yurenh2/fa" +COLORS = {"BP": "#2166ac", "FA": "#e08214", "DFA": "#b2182b"} + +plt.rcParams.update({ + "font.size": 9, "axes.labelsize": 10, "axes.titlesize": 10, + "legend.fontsize": 8, "xtick.labelsize": 8, "ytick.labelsize": 8, + "font.family": "serif", +}) + + +def extract_series(log): + epochs = [e['epoch'] for e in log] + if 'hidden_norms' in log[0]: + h_L = [e['hidden_norms'][-1] for e in log] + else: + h_L = [1.0] * len(log) + if 'bp_grad_norms_per_sample_med' in log[0]: + g_L = [e['bp_grad_norms_per_sample_med'][-1] for e in log] + elif 'bp_grad_per_sample_l2_med' in log[0]: + g_L = [e['bp_grad_per_sample_l2_med'][-1] for e in log] + else: + g_L = [1.0] * len(log) + acc = [e['acc_eval'] for e in log] + return epochs, h_L, g_L, acc + + +noln = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_no_outln_v1/snapshot_noLN_s42.json"))) + +# Try canonical FA; fall back to BP/DFA only +fa_path = os.path.join(REPO_ROOT, "results/snapshot_no_outln_v1/snapshot_fa_canonical_noln_s42.json") +has_fa = os.path.exists(fa_path) +if has_fa: + fa_noln = json.load(open(fa_path)) + +data = {"BP": extract_series(noln['bp_log']), "DFA": extract_series(noln['dfa_log'])} +if has_fa: + data["FA"] = extract_series(fa_noln['fa_log']) +methods = ["BP", "FA", "DFA"] if has_fa else ["BP", "DFA"] + +fig, axes = plt.subplots(1, 3, figsize=(10.5, 2.8)) +fig.subplots_adjust(wspace=0.35, left=0.07, right=0.97, bottom=0.18, top=0.92) + +# Column 0: ||h_L|| +ax = axes[0] +for m in methods: + ep, h, g, a = data[m] + ax.semilogy(ep, h, color=COLORS[m], linewidth=1.5, label=m) +ax.set_ylabel("$\\|h_L\\|_2$") +ax.set_xlabel("Epoch") +ax.set_title("$\\|h_L\\|$ (residual norm)") +ax.legend(loc="center right", fontsize=7) +ax.grid(True, which="major", color="#d0d0d0", linewidth=0.4, linestyle=":") +ax.grid(True, which="minor", color="#e8e8e8", linewidth=0.3, linestyle=":") +ax.set_axisbelow(True) + +# Column 1: ||g_L|| +ax = axes[1] +for m in methods: + ep, h, g, a = data[m] + ax.semilogy(ep, g, color=COLORS[m], linewidth=1.5, label=m) +ax.set_ylabel("$\\|g_L\\|_2$") +ax.set_xlabel("Epoch") +ax.set_title("$\\|g_L\\|$ (BP gradient at $h_L$)") +ax.grid(True, which="major", color="#d0d0d0", linewidth=0.4, linestyle=":") +ax.grid(True, which="minor", color="#e8e8e8", linewidth=0.3, linestyle=":") +ax.set_axisbelow(True) + +# Column 2: test acc +ax = axes[2] +for m in methods: + ep, h, g, a = data[m] + ax.plot(ep, a, color=COLORS[m], linewidth=1.5, label=m) +ax.set_ylabel("Test accuracy") +ax.set_xlabel("Epoch") +ax.set_title("Test accuracy") +ax.set_ylim(0, 0.7) +ax.grid(True, which="major", color="#d0d0d0", linewidth=0.4, linestyle=":") +ax.set_axisbelow(True) + +out = os.path.join(REPO_ROOT, "paper/figures/fig_nooutln_temporal.pdf") +fig.savefig(out, bbox_inches="tight", dpi=300) +fig.savefig(out.replace(".pdf", ".png"), bbox_inches="tight", dpi=200) +print(f"Saved: {out}") +if not has_fa: + print("NOTE: FA canonical data not yet available — will re-render when ready") |
