""" Figure 3b: Cross-architecture temporal evolution (3 rows × 3 columns = 9 panels). Row 1: ViT-Mini (terminal LN) Row 2: ResMLP no terminal LN Row 3: StudentNet (no LN) Columns: ||h_L||, ||g_L||, test acc Methods: BP (blue), FA (orange), DFA (red) """ import os, json import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np REPO_ROOT = "/home/yurenh2/fa" COLORS = {"BP": "#2166ac", "FA": "#e08214", "DFA": "#b2182b"} plt.rcParams.update({ "font.size": 9, "axes.labelsize": 10, "axes.titlesize": 10, "legend.fontsize": 8, "xtick.labelsize": 8, "ytick.labelsize": 8, "font.family": "serif", }) def extract_series(log): epochs = [e['epoch'] for e in log] if 'hidden_norms' in log[0]: h_L = [e['hidden_norms'][-1] for e in log] elif 'hidden_norms_cls' in log[0]: h_L = [e['hidden_norms_cls'][-1] for e in log] else: h_L = [1.0] * len(log) if 'bp_grad_norms_per_sample_med' in log[0]: g_L = [e['bp_grad_norms_per_sample_med'][-1] for e in log] elif 'bp_grad_per_sample_l2_med' in log[0]: g_L = [e['bp_grad_per_sample_l2_med'][-1] for e in log] else: g_L = [1.0] * len(log) acc = [e['acc_eval'] for e in log] return epochs, h_L, g_L, acc def add_grid(ax, log_scale=False): ax.grid(True, which="major", color="#d0d0d0", linewidth=0.4, linestyle=":") if log_scale: ax.grid(True, which="minor", color="#e8e8e8", linewidth=0.3, linestyle=":") ax.set_axisbelow(True) # Load data vit = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_vit_v1/snapshot_vit_s42.json"))) fa_vit = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_vit_v1/snapshot_fa_canonical_s42.json"))) noln = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_no_outln_v1/snapshot_noLN_s42.json"))) fa_noln = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_no_outln_v1/snapshot_fa_canonical_noln_s42.json"))) synth = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_synth_v1/snapshot_synth_a1.0_L4_s42.json"))) fa_synth = json.load(open(os.path.join(REPO_ROOT, "results/snapshot_synth_v1/snapshot_fa_canonical_s42.json"))) arch_data = [ ("ViT-Mini", vit, fa_vit), ("ResMLP no-LN", noln, fa_noln), ("StudentNet", synth, fa_synth), ] fig, axes = plt.subplots(3, 3, figsize=(10.5, 7.2)) fig.subplots_adjust(wspace=0.35, hspace=0.40, left=0.10, right=0.97, bottom=0.07, top=0.93) for row, (arch_name, arch_json, fa_json) in enumerate(arch_data): data = { "BP": extract_series(arch_json['bp_log']), "FA": extract_series(fa_json['fa_log']), "DFA": extract_series(arch_json['dfa_log']), } # Column 0: ||h_L|| ax = axes[row, 0] for m in ["BP", "FA", "DFA"]: ep, h, g, a = data[m] ax.semilogy(ep, h, color=COLORS[m], linewidth=1.5, label=m) ax.set_ylabel("$\\|h_L\\|_2$") if row == 0: ax.set_title("$\\|h_L\\|$ (residual norm)") ax.legend(loc="center right", fontsize=7) if row == 2: ax.set_xlabel("Epoch") add_grid(ax, log_scale=True) # Architecture label on the left ax.annotate(arch_name, xy=(0, 0.5), xytext=(-55, 0), xycoords="axes fraction", textcoords="offset points", fontsize=9, fontweight="bold", rotation=90, ha="center", va="center") # Column 1: ||g_L|| — shared y range across rows for comparison ax = axes[row, 1] for m in ["BP", "FA", "DFA"]: ep, h, g, a = data[m] ax.semilogy(ep, g, color=COLORS[m], linewidth=1.5) ax.set_ylabel("$\\|g_L\\|_2$") ax.set_ylim(1e-12, 5e-2) if row == 0: ax.set_title("$\\|g_L\\|$ (BP gradient at $h_L$)") if row == 2: ax.set_xlabel("Epoch") add_grid(ax, log_scale=True) # Column 2: test acc ax = axes[row, 2] for m in ["BP", "FA", "DFA"]: ep, h, g, a = data[m] ax.plot(ep, a, color=COLORS[m], linewidth=1.5) ax.set_ylabel("Test accuracy") if row == 0: ax.set_title("Test accuracy") if row == 2: ax.set_xlabel("Epoch") add_grid(ax) out = os.path.join(REPO_ROOT, "paper/figures/fig3b_crossarch_3row.pdf") fig.savefig(out, bbox_inches="tight", dpi=300) fig.savefig(out.replace(".pdf", ".png"), bbox_inches="tight", dpi=200) print(f"Saved: {out}")