""" Compare snapshot evolution JSONs across with-out_ln vs no-out_ln conditions and across seeds. Produces summary tables for the P4 figure. Usage: python experiments/snapshot_compare_outln.py """ import os, sys, json, glob import numpy as np def load(path): if not os.path.exists(path): return None with open(path) as f: return json.load(f) def field(d, key, layer=2): """Extract a per-epoch list of values for a given metric/layer.""" if d is None: return None log = d['bp_log'] if 'bp' in key else d['dfa_log'] if 'dfa' in key else None metric = key.replace('bp_', '').replace('dfa_', '') if metric == 'h_L_norm': return [r['hidden_norms'][-1] for r in log] if metric == 'h_L2_norm': return [r['hidden_norms'][2] if len(r['hidden_norms']) > 2 else None for r in log] if metric == 'g_l2': key_in_log = 'bp_grad_per_sample_l2_med' if 'bp_grad_per_sample_l2_med' in log[0] else 'bp_grad_norms_per_sample_med' return [r[key_in_log][layer] for r in log] if metric == 'acc': return [r['acc_eval'] for r in log] if metric == 'gamma_dfa': return [r.get('gamma_dfa', float('nan')) for r in log] return None def summary_row(d, label): """Print a summary row for the comparison table.""" if d is None: print(f"{label:35s} MISSING") return bp = d['bp_log'] dfa = d['dfa_log'] bp_eps = [r['epoch'] for r in bp] dfa_eps = [r['epoch'] for r in dfa] bp_h_L_init = bp[0]['hidden_norms'][-1] bp_h_L_final = bp[-1]['hidden_norms'][-1] dfa_h_L_init = dfa[0]['hidden_norms'][-1] dfa_h_L_final = dfa[-1]['hidden_norms'][-1] bp_g_key = 'bp_grad_per_sample_l2_med' if 'bp_grad_per_sample_l2_med' in bp[0] else 'bp_grad_norms_per_sample_med' bp_g2_init = bp[0][bp_g_key][2] bp_g2_final = bp[-1][bp_g_key][2] dfa_g2_init = dfa[0][bp_g_key][2] dfa_g2_final = dfa[-1][bp_g_key][2] bp_acc = bp[-1]['acc_eval'] dfa_acc = dfa[-1]['acc_eval'] bp_growth = bp_h_L_final / max(bp_h_L_init, 1e-12) dfa_growth = dfa_h_L_final / max(dfa_h_L_init, 1e-12) bp_g_change = bp_g2_final / max(bp_g2_init, 1e-30) dfa_g_change = dfa_g2_final / max(dfa_g2_init, 1e-30) print(f"{label:35s} BP_acc={bp_acc:.3f} DFA_acc={dfa_acc:.3f} " f"BP_||h_L||: {bp_h_L_init:.1e}→{bp_h_L_final:.1e} (×{bp_growth:.1e}) " f"DFA_||h_L||: {dfa_h_L_init:.1e}→{dfa_h_L_final:.1e} (×{dfa_growth:.1e}) " f"BP_||g_2||: {bp_g2_init:.1e}→{bp_g2_final:.1e} " f"DFA_||g_2||: {dfa_g2_init:.1e}→{dfa_g2_final:.1e}") def main(): print("=" * 130) print("SNAPSHOT EVOLUTION COMPARISON: with-out_ln vs no-out_ln vs synthetic") print("=" * 130) runs = [ ('with-out_ln s42 (ResMLP CIFAR)', 'results/snapshot_evolution_v2/snapshot_evolution_s42.json'), ('no-out_ln s42 (ResMLP CIFAR)', 'results/snapshot_no_outln_v1/snapshot_noLN_s42.json'), ('no-out_ln s123 (ResMLP CIFAR)', 'results/snapshot_no_outln_v1/snapshot_noLN_s123.json'), ('no-out_ln s456 (ResMLP CIFAR)', 'results/snapshot_no_outln_v1/snapshot_noLN_s456.json'), ('synthetic α=1 s42 (StudentNet)', 'results/snapshot_synth_v1/snapshot_synth_a1.0_L4_s42.json'), ] for label, path in runs: d = load(path) summary_row(d, label) print() print("Legend: ||h_L|| = median per-sample L2 norm of final hidden state; ||g_2|| = median per-sample L2 norm of BP gradient at h_2.") print("All norms use .norm(dim=-1), correct.") if __name__ == '__main__': main()