summaryrefslogtreecommitdiff
path: root/experiments/snapshot_compare_outln.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/snapshot_compare_outln.py')
-rw-r--r--experiments/snapshot_compare_outln.py93
1 files changed, 93 insertions, 0 deletions
diff --git a/experiments/snapshot_compare_outln.py b/experiments/snapshot_compare_outln.py
new file mode 100644
index 0000000..9b0ac7c
--- /dev/null
+++ b/experiments/snapshot_compare_outln.py
@@ -0,0 +1,93 @@
+"""
+Compare snapshot evolution JSONs across with-out_ln vs no-out_ln conditions
+and across seeds. Produces summary tables for the P4 figure.
+
+Usage:
+ python experiments/snapshot_compare_outln.py
+"""
+import os, sys, json, glob
+import numpy as np
+
+
+def load(path):
+ if not os.path.exists(path):
+ return None
+ with open(path) as f:
+ return json.load(f)
+
+
+def field(d, key, layer=2):
+ """Extract a per-epoch list of values for a given metric/layer."""
+ if d is None:
+ return None
+ log = d['bp_log'] if 'bp' in key else d['dfa_log'] if 'dfa' in key else None
+ metric = key.replace('bp_', '').replace('dfa_', '')
+ if metric == 'h_L_norm':
+ return [r['hidden_norms'][-1] for r in log]
+ if metric == 'h_L2_norm':
+ return [r['hidden_norms'][2] if len(r['hidden_norms']) > 2 else None for r in log]
+ if metric == 'g_l2':
+ key_in_log = 'bp_grad_per_sample_l2_med' if 'bp_grad_per_sample_l2_med' in log[0] else 'bp_grad_norms_per_sample_med'
+ return [r[key_in_log][layer] for r in log]
+ if metric == 'acc':
+ return [r['acc_eval'] for r in log]
+ if metric == 'gamma_dfa':
+ return [r.get('gamma_dfa', float('nan')) for r in log]
+ return None
+
+
+def summary_row(d, label):
+ """Print a summary row for the comparison table."""
+ if d is None:
+ print(f"{label:35s} MISSING")
+ return
+ bp = d['bp_log']
+ dfa = d['dfa_log']
+ bp_eps = [r['epoch'] for r in bp]
+ dfa_eps = [r['epoch'] for r in dfa]
+ bp_h_L_init = bp[0]['hidden_norms'][-1]
+ bp_h_L_final = bp[-1]['hidden_norms'][-1]
+ dfa_h_L_init = dfa[0]['hidden_norms'][-1]
+ dfa_h_L_final = dfa[-1]['hidden_norms'][-1]
+ bp_g_key = 'bp_grad_per_sample_l2_med' if 'bp_grad_per_sample_l2_med' in bp[0] else 'bp_grad_norms_per_sample_med'
+ bp_g2_init = bp[0][bp_g_key][2]
+ bp_g2_final = bp[-1][bp_g_key][2]
+ dfa_g2_init = dfa[0][bp_g_key][2]
+ dfa_g2_final = dfa[-1][bp_g_key][2]
+ bp_acc = bp[-1]['acc_eval']
+ dfa_acc = dfa[-1]['acc_eval']
+ bp_growth = bp_h_L_final / max(bp_h_L_init, 1e-12)
+ dfa_growth = dfa_h_L_final / max(dfa_h_L_init, 1e-12)
+ bp_g_change = bp_g2_final / max(bp_g2_init, 1e-30)
+ dfa_g_change = dfa_g2_final / max(dfa_g2_init, 1e-30)
+
+ print(f"{label:35s} BP_acc={bp_acc:.3f} DFA_acc={dfa_acc:.3f} "
+ f"BP_||h_L||: {bp_h_L_init:.1e}→{bp_h_L_final:.1e} (×{bp_growth:.1e}) "
+ f"DFA_||h_L||: {dfa_h_L_init:.1e}→{dfa_h_L_final:.1e} (×{dfa_growth:.1e}) "
+ f"BP_||g_2||: {bp_g2_init:.1e}→{bp_g2_final:.1e} "
+ f"DFA_||g_2||: {dfa_g2_init:.1e}→{dfa_g2_final:.1e}")
+
+
+def main():
+ print("=" * 130)
+ print("SNAPSHOT EVOLUTION COMPARISON: with-out_ln vs no-out_ln vs synthetic")
+ print("=" * 130)
+
+ runs = [
+ ('with-out_ln s42 (ResMLP CIFAR)', 'results/snapshot_evolution_v2/snapshot_evolution_s42.json'),
+ ('no-out_ln s42 (ResMLP CIFAR)', 'results/snapshot_no_outln_v1/snapshot_noLN_s42.json'),
+ ('no-out_ln s123 (ResMLP CIFAR)', 'results/snapshot_no_outln_v1/snapshot_noLN_s123.json'),
+ ('no-out_ln s456 (ResMLP CIFAR)', 'results/snapshot_no_outln_v1/snapshot_noLN_s456.json'),
+ ('synthetic α=1 s42 (StudentNet)', 'results/snapshot_synth_v1/snapshot_synth_a1.0_L4_s42.json'),
+ ]
+ for label, path in runs:
+ d = load(path)
+ summary_row(d, label)
+
+ print()
+ print("Legend: ||h_L|| = median per-sample L2 norm of final hidden state; ||g_2|| = median per-sample L2 norm of BP gradient at h_2.")
+ print("All norms use .norm(dim=-1), correct.")
+
+
+if __name__ == '__main__':
+ main()