diff options
Diffstat (limited to 'experiments/figure_snapshot_evolution.py')
| -rw-r--r-- | experiments/figure_snapshot_evolution.py | 178 |
1 files changed, 178 insertions, 0 deletions
diff --git a/experiments/figure_snapshot_evolution.py b/experiments/figure_snapshot_evolution.py new file mode 100644 index 0000000..b06f417 --- /dev/null +++ b/experiments/figure_snapshot_evolution.py @@ -0,0 +1,178 @@ +""" +Generate the snapshot-evolution figure(s) for the paper from existing JSONs. + +Produces: + - figure_snapshot_resmlp.pdf : ResMLP with vs without out_ln, ||h_L|| and ||g|| + over epochs for BP and DFA + - figure_snapshot_vit.pdf : ViT-Mini ||h_L|| and ||g|| over epochs for BP/DFA + +Usage: + python experiments/figure_snapshot_evolution.py +""" +import os, sys, json +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + + +def load_log(path, log_key): + if not os.path.exists(path): + return None + with open(path) as f: + return json.load(f).get(log_key) + + +def trajectory(log, metric): + """Extract a per-epoch trajectory for the given metric.""" + eps = [r['epoch'] for r in log] + if metric == 'h_L': + # last hidden norm — handles both ResMLP (hidden_norms) and ViT (hidden_norms_cls) + values = [] + for r in log: + if 'hidden_norms_cls' in r: + values.append(r['hidden_norms_cls'][-1]) + else: + values.append(r['hidden_norms'][-1]) + elif metric == 'g_2': + values = [] + for r in log: + key = 'bp_grad_per_sample_l2_med' if 'bp_grad_per_sample_l2_med' in r else 'bp_grad_norms_per_sample_med' + values.append(r[key][2]) + elif metric == 'acc': + values = [r['acc_eval'] for r in log] + elif metric == 'gamma_dfa': + values = [r.get('gamma_dfa', float('nan')) for r in log] + else: + return None, None + return np.array(eps), np.array(values) + + +def make_resmlp_figure(out_path): + fig, axes = plt.subplots(2, 2, figsize=(10, 7), sharex=True) + + runs = { + 'with_out_ln_s42': 'results/snapshot_evolution_v2/snapshot_evolution_s42.json', + 'no_out_ln_s42': 'results/snapshot_no_outln_v1/snapshot_noLN_s42.json', + 'no_out_ln_s123': 'results/snapshot_no_outln_v1/snapshot_noLN_s123.json', + 'no_out_ln_s456': 'results/snapshot_no_outln_v1/snapshot_noLN_s456.json', + } + runs_loaded = {k: (load_log(v, 'bp_log'), load_log(v, 'dfa_log')) for k, v in runs.items()} + + # Top row: with out_ln + bp, dfa = runs_loaded['with_out_ln_s42'] + ax = axes[0, 0] + e, v = trajectory(bp, 'h_L'); ax.plot(e, v, 'b-', label='BP', lw=2) + e, v = trajectory(dfa, 'h_L'); ax.plot(e, v, 'r-', label='DFA', lw=2) + ax.set_yscale('log'); ax.set_ylabel(r'$\|h_L\|_2$ (median)') + ax.set_title('ResMLP with terminal LayerNorm (s42)') + ax.legend(); ax.grid(True, alpha=0.3) + + ax = axes[0, 1] + e, v = trajectory(bp, 'g_2'); ax.plot(e, v, 'b-', label='BP', lw=2) + e, v = trajectory(dfa, 'g_2'); ax.plot(e, v, 'r-', label='DFA', lw=2) + ax.set_yscale('log'); ax.set_ylabel(r'$\|\nabla_{h_2} L\|_2$ (BP grad, median)') + ax.set_title('ResMLP with terminal LayerNorm (s42)') + ax.legend(); ax.grid(True, alpha=0.3) + + # Bottom row: no out_ln, mean ± std across 3 seeds + no_ln_bp_h = []; no_ln_bp_g = []; no_ln_dfa_h = []; no_ln_dfa_g = [] + for k in ['no_out_ln_s42', 'no_out_ln_s123', 'no_out_ln_s456']: + bp, dfa = runs_loaded[k] + if bp is None or dfa is None: continue + e_bp, h_bp = trajectory(bp, 'h_L'); _, g_bp = trajectory(bp, 'g_2') + e_dfa, h_dfa = trajectory(dfa, 'h_L'); _, g_dfa = trajectory(dfa, 'g_2') + no_ln_bp_h.append(h_bp); no_ln_bp_g.append(g_bp) + no_ln_dfa_h.append(h_dfa); no_ln_dfa_g.append(g_dfa) + + if no_ln_bp_h: + eps = e_bp + bp_h_arr = np.array(no_ln_bp_h) + bp_g_arr = np.array(no_ln_bp_g) + dfa_h_arr = np.array(no_ln_dfa_h) + dfa_g_arr = np.array(no_ln_dfa_g) + + ax = axes[1, 0] + ax.plot(eps, np.mean(bp_h_arr, 0), 'b-', label='BP', lw=2) + ax.fill_between(eps, np.mean(bp_h_arr, 0)-np.std(bp_h_arr, 0), np.mean(bp_h_arr, 0)+np.std(bp_h_arr, 0), color='b', alpha=0.2) + ax.plot(eps, np.mean(dfa_h_arr, 0), 'r-', label='DFA', lw=2) + ax.fill_between(eps, np.mean(dfa_h_arr, 0)-np.std(dfa_h_arr, 0), np.mean(dfa_h_arr, 0)+np.std(dfa_h_arr, 0), color='r', alpha=0.2) + ax.set_yscale('log'); ax.set_xlabel('epoch'); ax.set_ylabel(r'$\|h_L\|_2$ (median)') + ax.set_title(f'ResMLP WITHOUT terminal LayerNorm (mean ± std, n={len(no_ln_bp_h)})') + ax.legend(); ax.grid(True, alpha=0.3) + + ax = axes[1, 1] + ax.plot(eps, np.mean(bp_g_arr, 0), 'b-', label='BP', lw=2) + ax.fill_between(eps, np.mean(bp_g_arr, 0)-np.std(bp_g_arr, 0), np.mean(bp_g_arr, 0)+np.std(bp_g_arr, 0), color='b', alpha=0.2) + ax.plot(eps, np.mean(dfa_g_arr, 0), 'r-', label='DFA', lw=2) + ax.fill_between(eps, np.mean(dfa_g_arr, 0)-np.std(dfa_g_arr, 0), np.mean(dfa_g_arr, 0)+np.std(dfa_g_arr, 0), color='r', alpha=0.2) + ax.set_yscale('log'); ax.set_xlabel('epoch'); ax.set_ylabel(r'$\|\nabla_{h_2} L\|_2$ (BP grad, median)') + ax.set_title(f'ResMLP WITHOUT terminal LayerNorm (mean ± std, n={len(no_ln_bp_h)})') + ax.legend(); ax.grid(True, alpha=0.3) + + plt.suptitle('Snapshot evolution: residual stream + BP grad over training\n(top: with terminal LN — DFA explodes; bottom: no terminal LN — DFA still grows but BP grad does NOT collapse)', y=1.02) + plt.tight_layout() + plt.savefig(out_path, bbox_inches='tight', dpi=150) + print(f"Saved {out_path}") + plt.close() + + +def make_vit_figure(out_path): + fig, axes = plt.subplots(1, 2, figsize=(11, 4)) + + runs = sorted([ + f for f in os.listdir('results/snapshot_vit_v1') + if f.startswith('snapshot_vit_s') and f.endswith('.json') + ]) + if not runs: + print("No ViT snapshot JSONs found") + return + + bp_h_list = []; bp_g_list = []; dfa_h_list = []; dfa_g_list = [] + eps = None + for r in runs: + path = f'results/snapshot_vit_v1/{r}' + bp = load_log(path, 'bp_log') + dfa = load_log(path, 'dfa_log') + if bp is None or dfa is None: continue + e_bp, h_bp = trajectory(bp, 'h_L'); _, g_bp = trajectory(bp, 'g_2') + e_dfa, h_dfa = trajectory(dfa, 'h_L'); _, g_dfa = trajectory(dfa, 'g_2') + bp_h_list.append(h_bp); bp_g_list.append(g_bp) + dfa_h_list.append(h_dfa); dfa_g_list.append(g_dfa) + eps = e_bp + + bp_h_arr = np.array(bp_h_list); bp_g_arr = np.array(bp_g_list) + dfa_h_arr = np.array(dfa_h_list); dfa_g_arr = np.array(dfa_g_list) + + ax = axes[0] + ax.plot(eps, np.mean(bp_h_arr, 0), 'b-', label='BP', lw=2) + if len(bp_h_list) > 1: + ax.fill_between(eps, np.mean(bp_h_arr, 0)-np.std(bp_h_arr, 0), np.mean(bp_h_arr, 0)+np.std(bp_h_arr, 0), color='b', alpha=0.2) + ax.plot(eps, np.mean(dfa_h_arr, 0), 'r-', label='DFA', lw=2) + if len(dfa_h_list) > 1: + ax.fill_between(eps, np.mean(dfa_h_arr, 0)-np.std(dfa_h_arr, 0), np.mean(dfa_h_arr, 0)+np.std(dfa_h_arr, 0), color='r', alpha=0.2) + ax.set_yscale('log'); ax.set_xlabel('epoch'); ax.set_ylabel(r'$\|h_L^{cls}\|_2$ (median)') + ax.set_title(f'ViT-Mini, terminal LayerNorm (n={len(bp_h_list)})') + ax.legend(); ax.grid(True, alpha=0.3) + + ax = axes[1] + ax.plot(eps, np.mean(bp_g_arr, 0), 'b-', label='BP', lw=2) + if len(bp_g_list) > 1: + ax.fill_between(eps, np.mean(bp_g_arr, 0)-np.std(bp_g_arr, 0), np.mean(bp_g_arr, 0)+np.std(bp_g_arr, 0), color='b', alpha=0.2) + ax.plot(eps, np.mean(dfa_g_arr, 0), 'r-', label='DFA', lw=2) + if len(dfa_g_list) > 1: + ax.fill_between(eps, np.mean(dfa_g_arr, 0)-np.std(dfa_g_arr, 0), np.mean(dfa_g_arr, 0)+np.std(dfa_g_arr, 0), color='r', alpha=0.2) + ax.set_yscale('log'); ax.set_xlabel('epoch'); ax.set_ylabel(r'$\|\nabla_{h_2} L\|_2$ (BP grad, median)') + ax.set_title(f'ViT-Mini, terminal LayerNorm (n={len(bp_g_list)})') + ax.legend(); ax.grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig(out_path, bbox_inches='tight', dpi=150) + print(f"Saved {out_path}") + plt.close() + + +if __name__ == '__main__': + os.makedirs('results/figures', exist_ok=True) + make_resmlp_figure('results/figures/figure_snapshot_resmlp.pdf') + make_vit_figure('results/figures/figure_snapshot_vit.pdf') |
