summary | refs | log | tree | commit | diff
path: root/experiments/figure_snapshot_evolution.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/figure_snapshot_evolution.py')
-rw-r--r-- experiments/figure_snapshot_evolution.py 178
1 files changed, 178 insertions, 0 deletions
diff --git a/experiments/figure_snapshot_evolution.py b/experiments/figure_snapshot_evolution.py
new file mode 100644
index 0000000..b06f417
--- /dev/null
+++ b/experiments/figure_snapshot_evolution.py
@@ -0,0 +1,178 @@
+"""
+Generate the snapshot-evolution figure(s) for the paper from existing JSONs.
+
+Produces:
+ - figure_snapshot_resmlp.pdf : ResMLP with vs without out_ln, ||h_L|| and ||g||
+ over epochs for BP and DFA
+ - figure_snapshot_vit.pdf : ViT-Mini ||h_L|| and ||g|| over epochs for BP/DFA
+
+Usage:
+ python experiments/figure_snapshot_evolution.py
+"""
+import os, sys, json
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+
+
def load_log(path, log_key):
    """Load one named log from a snapshot JSON file.

    Args:
        path: Path of the JSON file to read.
        log_key: Top-level key to look up in the decoded JSON object
            (callers in this file pass 'bp_log' and 'dfa_log').

    Returns:
        The value stored under ``log_key``, or ``None`` when the file does
        not exist or the key is absent.
    """
    try:
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(path, encoding='utf-8') as f:
            payload = json.load(f)
    except FileNotFoundError:
        return None
    return payload.get(log_key)
+
+
def trajectory(log, metric):
    """Extract a per-epoch trajectory for the given metric.

    Args:
        log: Sequence of per-epoch record dicts, each holding an 'epoch'
            key. ``None`` (what ``load_log`` returns for a missing file or
            key) is treated as an empty log instead of raising TypeError.
        metric: One of 'h_L' (last entry of the hidden-norm list), 'g_2'
            (entry at index 2 of the per-sample BP gradient medians),
            'acc' (the 'acc_eval' field) or 'gamma_dfa' (NaN when absent).

    Returns:
        ``(epochs, values)`` as two NumPy arrays of equal length, or
        ``(None, None)`` when ``metric`` is not recognised.

    Raises:
        KeyError: If a record lacks 'epoch' or a field the metric requires.
    """
    def last_hidden_norm(rec):
        # ViT records use 'hidden_norms_cls'; ResMLP records use 'hidden_norms'.
        norms = rec['hidden_norms_cls'] if 'hidden_norms_cls' in rec else rec['hidden_norms']
        return norms[-1]

    def grad_at_index_2(rec):
        # Two key spellings occur in the logs; prefer the '_l2_med' one.
        if 'bp_grad_per_sample_l2_med' in rec:
            key = 'bp_grad_per_sample_l2_med'
        else:
            key = 'bp_grad_norms_per_sample_med'
        return rec[key][2]

    extractors = {
        'h_L': last_hidden_norm,
        'g_2': grad_at_index_2,
        'acc': lambda rec: rec['acc_eval'],
        'gamma_dfa': lambda rec: rec.get('gamma_dfa', float('nan')),
    }

    records = [] if log is None else log
    # Epochs are read before the metric is validated, so a record without
    # 'epoch' raises KeyError even for an unknown metric (as before).
    epochs = [rec['epoch'] for rec in records]

    extract = extractors.get(metric)
    if extract is None:
        return None, None
    return np.array(epochs), np.array([extract(rec) for rec in records])
+
+
def make_resmlp_figure(out_path):
    """Render the 2x2 ResMLP snapshot-evolution figure and save it.

    Layout (BP in blue, DFA in red, log-scaled y axes, shared x axis):
      - top row: the single seed-42 run with terminal LayerNorm;
      - bottom row: mean ± std over the no-LayerNorm seeds whose JSON
        provides both 'bp_log' and 'dfa_log';
      - left column: 'h_L' trajectory; right column: 'g_2' trajectory.

    Missing top-row logs are reported on stdout and skipped rather than
    crashing; the bottom row is left empty when no seed is complete.

    Args:
        out_path: Destination file passed to ``Figure.savefig``.
    """
    runs = {
        'with_out_ln_s42': 'results/snapshot_evolution_v2/snapshot_evolution_s42.json',
        'no_out_ln_s42': 'results/snapshot_no_outln_v1/snapshot_noLN_s42.json',
        'no_out_ln_s123': 'results/snapshot_no_outln_v1/snapshot_noLN_s123.json',
        'no_out_ln_s456': 'results/snapshot_no_outln_v1/snapshot_noLN_s456.json',
    }
    # (log key, matplotlib colour, legend label)
    methods = (('bp_log', 'b', 'BP'), ('dfa_log', 'r', 'DFA'))
    # (metric name understood by trajectory(), y-axis label)
    metrics = (
        ('h_L', r'$\|h_L\|_2$ (median)'),
        ('g_2', r'$\|\nabla_{h_2} L\|_2$ (BP grad, median)'),
    )

    def decorate(ax, title, ylabel, xlabel=None):
        # Shared axis cosmetics for every panel.
        ax.set_yscale('log')
        if xlabel is not None:
            ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        # Skip the legend on a panel where nothing could be drawn.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        ax.grid(True, alpha=0.3)

    def mean_band(ax, pairs, colour, label):
        # Draw the across-seed mean and a ± one-std band. Series are cut to
        # the shortest run so seeds with differing epoch counts still stack.
        n = min(len(values) for _, values in pairs)
        stacked = np.array([values[:n] for _, values in pairs])
        x = pairs[-1][0][:n]
        mean, std = np.mean(stacked, 0), np.std(stacked, 0)
        ax.plot(x, mean, colour + '-', label=label, lw=2)
        ax.fill_between(x, mean - std, mean + std, color=colour, alpha=0.2)

    loaded = {
        name: {key: load_log(path, key) for key, _, _ in methods}
        for name, path in runs.items()
    }

    fig, axes = plt.subplots(2, 2, figsize=(10, 7), sharex=True)

    # Top row: with out_ln (single seed).
    top = loaded['with_out_ln_s42']
    for key, _, label in methods:
        if top[key] is None:
            print(f"Missing '{key}' in {runs['with_out_ln_s42']}; skipping {label} in top row")
    for col, (metric, ylabel) in enumerate(metrics):
        ax = axes[0, col]
        for key, colour, label in methods:
            if top[key] is None:
                continue
            e, v = trajectory(top[key], metric)
            ax.plot(e, v, colour + '-', label=label, lw=2)
        decorate(ax, 'ResMLP with terminal LayerNorm (s42)', ylabel)

    # Bottom row: no out_ln, mean ± std across the seeds that have both logs.
    complete = [
        loaded[name]
        for name in ('no_out_ln_s42', 'no_out_ln_s123', 'no_out_ln_s456')
        if all(loaded[name][key] is not None for key, _, _ in methods)
    ]
    if complete:
        title = f'ResMLP WITHOUT terminal LayerNorm (mean ± std, n={len(complete)})'
        for col, (metric, ylabel) in enumerate(metrics):
            ax = axes[1, col]
            for key, colour, label in methods:
                pairs = [trajectory(seed[key], metric) for seed in complete]
                mean_band(ax, pairs, colour, label)
            decorate(ax, title, ylabel, xlabel='epoch')

    fig.suptitle(
        'Snapshot evolution: residual stream + BP grad over training\n'
        '(top: with terminal LN — DFA explodes; bottom: no terminal LN — '
        'DFA still grows but BP grad does NOT collapse)',
        y=1.02,
    )
    fig.tight_layout()
    fig.savefig(out_path, bbox_inches='tight', dpi=150)
    print(f"Saved {out_path}")
    plt.close(fig)
+
+
def make_vit_figure(out_path):
    """Render the 1x2 ViT-Mini snapshot-evolution figure and save it.

    Reads every ``snapshot_vit_s*.json`` in ``results/snapshot_vit_v1`` and
    plots the across-run mean of the 'h_L' (left) and 'g_2' (right)
    trajectories for BP (blue) and DFA (red) on log-scaled y axes. A ± std
    band is drawn only when more than one run is available.

    Nothing is written, and a message is printed, when the directory is
    missing, holds no matching files, or no file provides both 'bp_log'
    and 'dfa_log'.

    Args:
        out_path: Destination file passed to ``Figure.savefig``.
    """
    results_dir = 'results/snapshot_vit_v1'
    try:
        entries = os.listdir(results_dir)
    except FileNotFoundError:
        # A missing results directory is handled like an empty one.
        entries = []
    runs = sorted(
        f for f in entries
        if f.startswith('snapshot_vit_s') and f.endswith('.json')
    )
    if not runs:
        print("No ViT snapshot JSONs found")
        return

    # (log key, matplotlib colour, legend label)
    methods = (('bp_log', 'b', 'BP'), ('dfa_log', 'r', 'DFA'))
    # (metric name understood by trajectory(), y-axis label)
    metrics = (
        ('h_L', r'$\|h_L^{cls}\|_2$ (median)'),
        ('g_2', r'$\|\nabla_{h_2} L\|_2$ (BP grad, median)'),
    )

    # Keep only runs for which both logs are present.
    complete = []
    for name in runs:
        path = f'{results_dir}/{name}'
        logs = {key: load_log(path, key) for key, _, _ in methods}
        if any(log is None for log in logs.values()):
            continue
        complete.append(logs)
    if not complete:
        print("No ViT snapshot JSONs with both bp_log and dfa_log found")
        return

    # The figure is created only once there is data, so early returns above
    # never leave an unclosed figure behind.
    fig, axes = plt.subplots(1, 2, figsize=(11, 4))

    for ax, (metric, ylabel) in zip(axes, metrics):
        for key, colour, label in methods:
            pairs = [trajectory(logs[key], metric) for logs in complete]
            # Cut to the shortest run so differing epoch counts still stack.
            n = min(len(values) for _, values in pairs)
            stacked = np.array([values[:n] for _, values in pairs])
            x = pairs[-1][0][:n]
            mean = np.mean(stacked, 0)
            ax.plot(x, mean, colour + '-', label=label, lw=2)
            if len(pairs) > 1:
                std = np.std(stacked, 0)
                ax.fill_between(x, mean - std, mean + std, color=colour, alpha=0.2)
        ax.set_yscale('log')
        ax.set_xlabel('epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(f'ViT-Mini, terminal LayerNorm (n={len(complete)})')
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(out_path, bbox_inches='tight', dpi=150)
    print(f"Saved {out_path}")
    plt.close(fig)
+
+
if __name__ == '__main__':
    # Script entry point: make sure the output directory exists, then build
    # each figure in turn (ResMLP first, ViT second).
    os.makedirs('results/figures', exist_ok=True)
    for build_figure, target in (
        (make_resmlp_figure, 'results/figures/figure_snapshot_resmlp.pdf'),
        (make_vit_figure, 'results/figures/figure_snapshot_vit.pdf'),
    ):
        build_figure(target)