"""Generate the snapshot-evolution figure(s) for the paper from existing JSONs.

Produces:
  - figure_snapshot_resmlp.pdf : ResMLP with vs. without out_ln, ||h_L|| and
    ||g|| over epochs for BP and DFA
  - figure_snapshot_vit.pdf    : ViT-Mini ||h_L|| and ||g|| over epochs for
    BP/DFA

Usage:
    python experiments/figure_snapshot_evolution.py
"""
import json
import os
import sys  # noqa: F401  (kept from the original import line; currently unused)

import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt

# Axis labels shared by several panels.
_H_LABEL = r'$\|h_L\|_2$ (median)'
_G_LABEL = r'$\|\nabla_{h_2} L\|_2$ (BP grad, median)'


def load_log(path, log_key):
    """Return ``data[log_key]`` from the JSON file at *path*.

    Returns None when the file does not exist or the key is absent, so callers
    can treat a missing run as "skip".
    """
    if not os.path.exists(path):
        return None
    with open(path, encoding='utf-8') as f:
        return json.load(f).get(log_key)


def trajectory(log, metric):
    """Extract a per-epoch trajectory for the given metric.

    Args:
        log: sequence of per-epoch record dicts, each with an 'epoch' key plus
            the metric-specific keys read below.
        metric: one of
            'h_L'       - last entry of 'hidden_norms_cls' (ViT) if present,
                          otherwise last entry of 'hidden_norms' (ResMLP);
            'g_2'       - entry 2 of 'bp_grad_per_sample_l2_med' if present,
                          otherwise of 'bp_grad_norms_per_sample_med';
            'acc'       - 'acc_eval';
            'gamma_dfa' - 'gamma_dfa', NaN when a record lacks it.

    Returns:
        (epochs, values) as numpy arrays, or (None, None) for an unknown metric.
    """
    if metric == 'h_L':
        values = [r['hidden_norms_cls'][-1] if 'hidden_norms_cls' in r
                  else r['hidden_norms'][-1]
                  for r in log]
    elif metric == 'g_2':
        values = []
        for r in log:
            # Older logs use the second key name; both hold per-layer values,
            # and index 2 is the layer labelled h_2 in the plots.
            key = ('bp_grad_per_sample_l2_med' if 'bp_grad_per_sample_l2_med' in r
                   else 'bp_grad_norms_per_sample_med')
            values.append(r[key][2])
    elif metric == 'acc':
        values = [r['acc_eval'] for r in log]
    elif metric == 'gamma_dfa':
        values = [r.get('gamma_dfa', float('nan')) for r in log]
    else:
        return None, None
    eps = [r['epoch'] for r in log]
    return np.array(eps), np.array(values)


def _collect_seed_runs(paths):
    """Load paired BP/DFA logs from *paths* and stack their h_L / g_2 trajectories.

    A run is skipped when its file, 'bp_log' or 'dfa_log' is missing or empty.
    All trajectories are truncated to the shortest length found, so runs (or
    BP vs. DFA) with different epoch counts can still be averaged; stacking
    ragged lists with np.array would otherwise fail.

    Returns:
        (eps, series, n_runs), where series maps 'bp_h', 'bp_g', 'dfa_h',
        'dfa_g' to (n_runs, T) float arrays and eps holds the T epoch values
        of the last usable run's BP log. Returns None if no run was usable.
    """
    raw = {'bp_h': [], 'bp_g': [], 'dfa_h': [], 'dfa_g': []}
    eps = None
    for path in paths:
        bp = load_log(path, 'bp_log')
        dfa = load_log(path, 'dfa_log')
        if not bp or not dfa:
            continue
        e_bp, h_bp = trajectory(bp, 'h_L')
        raw['bp_h'].append(h_bp)
        raw['bp_g'].append(trajectory(bp, 'g_2')[1])
        raw['dfa_h'].append(trajectory(dfa, 'h_L')[1])
        raw['dfa_g'].append(trajectory(dfa, 'g_2')[1])
        eps = e_bp
    if eps is None:
        return None

    lengths = [len(t) for runs in raw.values() for t in runs]
    n_epochs = min(lengths)
    if max(lengths) != n_epochs:
        print(f"Note: trajectories have unequal lengths; truncating to {n_epochs} epochs")
    series = {
        k: np.stack([np.asarray(t[:n_epochs], dtype=float) for t in runs])
        for k, runs in raw.items()
    }
    return eps[:n_epochs], series, len(raw['bp_h'])


def _plot_mean_std(ax, eps, runs, color, label):
    """Plot the across-run mean of *runs* (n_runs, T) against *eps*.

    A shaded mean +/- 1 std band is added only when there is more than one run.
    """
    mean = np.mean(runs, 0)
    ax.plot(eps, mean, f'{color}-', label=label, lw=2)
    if runs.shape[0] > 1:
        std = np.std(runs, 0)
        ax.fill_between(eps, mean - std, mean + std, color=color, alpha=0.2)


def _finish_axis(ax, ylabel, title, xlabel=None):
    """Apply the log-y scale, labels, title, legend and grid shared by every panel."""
    ax.set_yscale('log')
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


def make_resmlp_figure(out_path):
    """Write the 2x2 ResMLP snapshot figure to *out_path*.

    Top row: single seed (s42) with terminal LayerNorm, BP vs. DFA.
    Bottom row: no terminal LayerNorm, mean +/- std over the seeds found.
    Missing runs are reported and their panels left empty instead of crashing.
    """
    runs = {
        'with_out_ln_s42': 'results/snapshot_evolution_v2/snapshot_evolution_s42.json',
        'no_out_ln_s42': 'results/snapshot_no_outln_v1/snapshot_noLN_s42.json',
        'no_out_ln_s123': 'results/snapshot_no_outln_v1/snapshot_noLN_s123.json',
        'no_out_ln_s456': 'results/snapshot_no_outln_v1/snapshot_noLN_s456.json',
    }
    fig, axes = plt.subplots(2, 2, figsize=(10, 7), sharex=True)

    # Top row: with out_ln (single seed, no band).
    with_path = runs['with_out_ln_s42']
    bp = load_log(with_path, 'bp_log')
    dfa = load_log(with_path, 'dfa_log')
    if not bp and not dfa:
        print(f"Missing with-out_ln run logs: {with_path}; top row left empty")
    for ax, metric, ylabel in ((axes[0, 0], 'h_L', _H_LABEL), (axes[0, 1], 'g_2', _G_LABEL)):
        for log, fmt, label in ((bp, 'b-', 'BP'), (dfa, 'r-', 'DFA')):
            if log:
                e, v = trajectory(log, metric)
                ax.plot(e, v, fmt, label=label, lw=2)
        _finish_axis(ax, ylabel, 'ResMLP with terminal LayerNorm (s42)')

    # Bottom row: no out_ln, mean ± std across seeds.
    collected = _collect_seed_runs(
        [runs[k] for k in ('no_out_ln_s42', 'no_out_ln_s123', 'no_out_ln_s456')]
    )
    if collected is None:
        print("No no-out_ln ResMLP runs found; bottom row left empty")
    else:
        eps, series, n_runs = collected
        title = f'ResMLP WITHOUT terminal LayerNorm (mean ± std, n={n_runs})'
        for ax, key, ylabel in ((axes[1, 0], 'h', _H_LABEL), (axes[1, 1], 'g', _G_LABEL)):
            _plot_mean_std(ax, eps, series[f'bp_{key}'], 'b', 'BP')
            _plot_mean_std(ax, eps, series[f'dfa_{key}'], 'r', 'DFA')
            _finish_axis(ax, ylabel, title, xlabel='epoch')

    fig.suptitle('Snapshot evolution: residual stream + BP grad over training\n(top: with terminal LN — DFA explodes; bottom: no terminal LN — DFA still grows but BP grad does NOT collapse)', y=1.02)
    fig.tight_layout()
    fig.savefig(out_path, bbox_inches='tight', dpi=150)
    print(f"Saved {out_path}")
    plt.close(fig)


def make_vit_figure(out_path):
    """Write the 1x2 ViT-Mini snapshot figure (mean +/- std over seeds) to *out_path*.

    Reads every ``snapshot_vit_s*.json`` in ``results/snapshot_vit_v1``.
    Returns without writing (and without leaking a figure) when no usable run exists.
    """
    vit_dir = 'results/snapshot_vit_v1'
    names = []
    if os.path.isdir(vit_dir):
        names = sorted(
            f for f in os.listdir(vit_dir)
            if f.startswith('snapshot_vit_s') and f.endswith('.json')
        )
    if not names:
        print("No ViT snapshot JSONs found")
        return

    collected = _collect_seed_runs([os.path.join(vit_dir, f) for f in names])
    if collected is None:
        print("No ViT snapshot JSONs with both bp_log and dfa_log found")
        return
    eps, series, n_runs = collected

    fig, axes = plt.subplots(1, 2, figsize=(11, 4))
    title = f'ViT-Mini, terminal LayerNorm (n={n_runs})'
    panels = ((axes[0], 'h', r'$\|h_L^{cls}\|_2$ (median)'), (axes[1], 'g', _G_LABEL))
    for ax, key, ylabel in panels:
        _plot_mean_std(ax, eps, series[f'bp_{key}'], 'b', 'BP')
        _plot_mean_std(ax, eps, series[f'dfa_{key}'], 'r', 'DFA')
        _finish_axis(ax, ylabel, title, xlabel='epoch')

    fig.tight_layout()
    fig.savefig(out_path, bbox_inches='tight', dpi=150)
    print(f"Saved {out_path}")
    plt.close(fig)


if __name__ == '__main__':
    os.makedirs('results/figures', exist_ok=True)
    make_resmlp_figure('results/figures/figure_snapshot_resmlp.pdf')
    make_vit_figure('results/figures/figure_snapshot_vit.pdf')