diff options
Diffstat (limited to 'experiments/snapshot_compare_outln.py')
| -rw-r--r-- | experiments/snapshot_compare_outln.py | 93 |
1 files changed, 93 insertions, 0 deletions
# File: experiments/snapshot_compare_outln.py
"""
Compare snapshot evolution JSONs across with-out_ln vs no-out_ln conditions
and across seeds. Produces summary tables for the P4 figure.

Usage:
    python experiments/snapshot_compare_outln.py
"""
import glob
import json
import os
import sys

import numpy as np

# Record keys under which the per-sample gradient norms may be stored.
# The first one present in a record wins; the second is the fallback.
_GRAD_KEYS = ('bp_grad_per_sample_l2_med', 'bp_grad_norms_per_sample_med')

# Layer index used for the ||g_2|| columns of the summary table.
_GRAD_LAYER = 2


def load(path):
    """Return the parsed JSON stored at *path*, or None if the file is absent.

    Only a missing file is mapped to None; any other I/O error or malformed
    JSON propagates to the caller.
    """
    # EAFP: avoids the exists()/open() race of checking first.
    try:
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return None


def _grad_key(record):
    """Return the gradient-norm key present in *record* (fallback: second key)."""
    return _GRAD_KEYS[0] if _GRAD_KEYS[0] in record else _GRAD_KEYS[1]


def field(d, key, layer=2):
    """Extract a per-epoch list of values for a given metric/layer.

    Args:
        d: Parsed snapshot dict with 'bp_log' and 'dfa_log' record lists,
            or None.
        key: Metric name. A key containing 'bp' reads from 'bp_log';
            otherwise a key containing 'dfa' reads from 'dfa_log'. After
            removing 'bp_' / 'dfa_' the metric must be one of 'h_L_norm',
            'h_L2_norm', 'g_l2', 'acc' or 'gamma_dfa'.
        layer: Index into the per-layer gradient list (used by 'g_l2' only).

    Returns:
        A list with one value per log record, or None when *d* is None, the
        key selects no log, or the metric is unknown.
    """
    if d is None:
        return None
    if 'bp' in key:
        log = d['bp_log']
    elif 'dfa' in key:
        log = d['dfa_log']
    else:
        # No log selected: report "unknown" instead of iterating over None.
        return None
    metric = key.replace('bp_', '').replace('dfa_', '')
    if metric == 'h_L_norm':
        return [r['hidden_norms'][-1] for r in log]
    if metric == 'h_L2_norm':
        return [r['hidden_norms'][2] if len(r['hidden_norms']) > 2 else None
                for r in log]
    if metric == 'g_l2':
        if not log:
            # Nothing to inspect for the key name; an empty log has no values.
            return []
        grad_key = _grad_key(log[0])
        return [r[grad_key][layer] for r in log]
    if metric == 'acc':
        return [r['acc_eval'] for r in log]
    if metric == 'gamma_dfa':
        return [r.get('gamma_dfa', float('nan')) for r in log]
    return None


def _log_stats(log):
    """Return (h_L init, h_L final, g_2 init, g_2 final, final acc) of a non-empty log."""
    first, last = log[0], log[-1]
    # Pick the key from this log itself, mirroring what field() does.
    grad_key = _grad_key(first)
    return (
        first['hidden_norms'][-1],
        last['hidden_norms'][-1],
        first[grad_key][_GRAD_LAYER],
        last[grad_key][_GRAD_LAYER],
        last['acc_eval'],
    )


def summary_row(d, label):
    """Print a summary row for the comparison table.

    Args:
        d: Parsed snapshot dict with 'bp_log' and 'dfa_log' record lists,
            or None when the run's file was not found.
        label: Row label, left-aligned in a 35-character column.

    Prints a 'MISSING' row when *d* is None and an 'EMPTY LOG' row when
    either log has no records, so one bad run does not abort the table.
    """
    if d is None:
        print(f"{label:35s} MISSING")
        return
    bp = d['bp_log']
    dfa = d['dfa_log']
    if not bp or not dfa:
        print(f"{label:35s} EMPTY LOG")
        return

    bp_h_L_init, bp_h_L_final, bp_g2_init, bp_g2_final, bp_acc = _log_stats(bp)
    dfa_h_L_init, dfa_h_L_final, dfa_g2_init, dfa_g2_final, dfa_acc = _log_stats(dfa)

    # Floor the denominator so a zero initial norm cannot divide by zero.
    bp_growth = bp_h_L_final / max(bp_h_L_init, 1e-12)
    dfa_growth = dfa_h_L_final / max(dfa_h_L_init, 1e-12)

    print(f"{label:35s} BP_acc={bp_acc:.3f} DFA_acc={dfa_acc:.3f} "
          f"BP_||h_L||: {bp_h_L_init:.1e}→{bp_h_L_final:.1e} (×{bp_growth:.1e}) "
          f"DFA_||h_L||: {dfa_h_L_init:.1e}→{dfa_h_L_final:.1e} (×{dfa_growth:.1e}) "
          f"BP_||g_2||: {bp_g2_init:.1e}→{bp_g2_final:.1e} "
          f"DFA_||g_2||: {dfa_g2_init:.1e}→{dfa_g2_final:.1e}")


def main():
    """Print the comparison table for the fixed list of snapshot runs.

    Paths are relative to the current working directory; runs whose file is
    absent show up as 'MISSING' rows.
    """
    print("=" * 130)
    print("SNAPSHOT EVOLUTION COMPARISON: with-out_ln vs no-out_ln vs synthetic")
    print("=" * 130)

    runs = [
        ('with-out_ln s42 (ResMLP CIFAR)', 'results/snapshot_evolution_v2/snapshot_evolution_s42.json'),
        ('no-out_ln s42 (ResMLP CIFAR)', 'results/snapshot_no_outln_v1/snapshot_noLN_s42.json'),
        ('no-out_ln s123 (ResMLP CIFAR)', 'results/snapshot_no_outln_v1/snapshot_noLN_s123.json'),
        ('no-out_ln s456 (ResMLP CIFAR)', 'results/snapshot_no_outln_v1/snapshot_noLN_s456.json'),
        ('synthetic α=1 s42 (StudentNet)', 'results/snapshot_synth_v1/snapshot_synth_a1.0_L4_s42.json'),
    ]
    for label, path in runs:
        d = load(path)
        summary_row(d, label)

    print()
    print("Legend: ||h_L|| = median per-sample L2 norm of final hidden state; ||g_2|| = median per-sample L2 norm of BP gradient at h_2.")
    print("All norms use .norm(dim=-1), correct.")


if __name__ == '__main__':
    main()
