summaryrefslogtreecommitdiff
path: root/experiments/snapshot_compare_outln.py
blob: 9b0ac7c83fd8025708ad13c7d37f4d5c322bb1a8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
Compare snapshot evolution JSONs across with-out_ln vs no-out_ln conditions
and across seeds. Produces summary tables for the P4 figure.

Usage:
    python experiments/snapshot_compare_outln.py
"""
import os, sys, json, glob
import numpy as np


def load(path):
    """Load a JSON results file.

    Args:
        path: Path to the JSON file.

    Returns:
        The decoded JSON object, or ``None`` if ``path`` does not exist.
    """
    # EAFP instead of exists()+open(): no race between the check and the open.
    # JSON is UTF-8 by spec, so decode explicitly rather than relying on the
    # platform's locale-dependent default encoding.
    try:
        with open(path, encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        return None


def field(d, key, layer=2):
    """Extract a per-epoch list of values for one metric from a snapshot dict.

    Args:
        d: Parsed snapshot JSON (as returned by ``load``) holding ``'bp_log'``
            and ``'dfa_log'`` lists of per-epoch record dicts, or ``None``.
        key: ``'bp_<metric>'`` or ``'dfa_<metric>'``. The prefix selects the
            log; the remainder selects the metric. Supported metrics:
            ``h_L_norm`` (last entry of ``hidden_norms``), ``h_L2_norm``
            (entry 2 of ``hidden_norms``, ``None`` if absent), ``g_l2``,
            ``acc`` (``acc_eval``) and ``gamma_dfa`` (NaN when missing).
        layer: Index into the per-sample gradient-norm list; only used by
            ``g_l2``.

    Returns:
        One value per log record, or ``None`` when ``d`` is ``None`` or the
        key's prefix or metric is not recognised.
    """
    if d is None:
        return None
    if key.startswith('bp_'):
        log = d['bp_log']
    elif key.startswith('dfa_'):
        log = d['dfa_log']
    else:
        # Without a recognised prefix there is no log to read; returning here
        # avoids iterating over None below.
        return None
    metric = key.split('_', 1)[1]

    if metric == 'h_L_norm':
        return [r['hidden_norms'][-1] for r in log]
    if metric == 'h_L2_norm':
        return [r['hidden_norms'][2] if len(r['hidden_norms']) > 2 else None for r in log]
    if metric == 'g_l2':
        if not log:
            # The key lookup below inspects log[0]; an empty log has no values.
            return []
        # Older runs stored the per-sample gradient medians under a different key.
        key_in_log = ('bp_grad_per_sample_l2_med'
                      if 'bp_grad_per_sample_l2_med' in log[0]
                      else 'bp_grad_norms_per_sample_med')
        return [r[key_in_log][layer] for r in log]
    if metric == 'acc':
        return [r['acc_eval'] for r in log]
    if metric == 'gamma_dfa':
        return [r.get('gamma_dfa', float('nan')) for r in log]
    return None


def summary_row(d, label):
    """Print one row of the BP-vs-DFA comparison table.

    The row shows final eval accuracy for both logs, the first->last value of
    ``hidden_norms[-1]`` (||h_L||) with its growth factor, and the first->last
    value of gradient-norm entry 2 (||g_2||).

    Args:
        d: Parsed snapshot JSON (as returned by ``load``) holding ``'bp_log'``
            and ``'dfa_log'`` record lists, or ``None`` for a missing run.
        label: Row label, left-aligned in a 35-character column.
    """
    if d is None:
        print(f"{label:35s}  MISSING")
        return
    bp = d['bp_log']
    dfa = d['dfa_log']
    if not bp or not dfa:
        # The row needs first and last records; an empty log has neither.
        print(f"{label:35s}  EMPTY LOG")
        return

    def grad_key(log):
        # Older runs used a different key for the per-sample gradient medians.
        # Resolve it per log so BP and DFA logs need not share a key name.
        return ('bp_grad_per_sample_l2_med'
                if 'bp_grad_per_sample_l2_med' in log[0]
                else 'bp_grad_norms_per_sample_med')

    bp_g_key = grad_key(bp)
    dfa_g_key = grad_key(dfa)

    bp_h_L_init = bp[0]['hidden_norms'][-1]
    bp_h_L_final = bp[-1]['hidden_norms'][-1]
    dfa_h_L_init = dfa[0]['hidden_norms'][-1]
    dfa_h_L_final = dfa[-1]['hidden_norms'][-1]
    bp_g2_init = bp[0][bp_g_key][2]
    bp_g2_final = bp[-1][bp_g_key][2]
    dfa_g2_init = dfa[0][dfa_g_key][2]
    dfa_g2_final = dfa[-1][dfa_g_key][2]
    bp_acc = bp[-1]['acc_eval']
    dfa_acc = dfa[-1]['acc_eval']
    # Floor the denominator so a zero initial norm does not divide by zero.
    bp_growth = bp_h_L_final / max(bp_h_L_init, 1e-12)
    dfa_growth = dfa_h_L_final / max(dfa_h_L_init, 1e-12)

    print(f"{label:35s}  BP_acc={bp_acc:.3f} DFA_acc={dfa_acc:.3f}  "
          f"BP_||h_L||: {bp_h_L_init:.1e}→{bp_h_L_final:.1e} (×{bp_growth:.1e})  "
          f"DFA_||h_L||: {dfa_h_L_init:.1e}→{dfa_h_L_final:.1e} (×{dfa_growth:.1e})  "
          f"BP_||g_2||: {bp_g2_init:.1e}→{bp_g2_final:.1e}  "
          f"DFA_||g_2||: {dfa_g2_init:.1e}→{dfa_g2_final:.1e}")


def main():
    """Print the comparison table for each configured snapshot run.

    Each row is produced by loading a result JSON and formatting it with
    ``summary_row``; a missing file yields a MISSING row. The result paths are
    relative, so they resolve against the current working directory. The
    module's usage line suggests running from the repository root.
    """
    rule = "=" * 130
    for banner_line in (
        rule,
        "SNAPSHOT EVOLUTION COMPARISON: with-out_ln vs no-out_ln vs synthetic",
        rule,
    ):
        print(banner_line)

    # (row label, result JSON path) for every run in the table, in print order.
    table_rows = (
        ('with-out_ln s42 (ResMLP CIFAR)', 'results/snapshot_evolution_v2/snapshot_evolution_s42.json'),
        ('no-out_ln s42 (ResMLP CIFAR)',   'results/snapshot_no_outln_v1/snapshot_noLN_s42.json'),
        ('no-out_ln s123 (ResMLP CIFAR)',  'results/snapshot_no_outln_v1/snapshot_noLN_s123.json'),
        ('no-out_ln s456 (ResMLP CIFAR)',  'results/snapshot_no_outln_v1/snapshot_noLN_s456.json'),
        ('synthetic α=1 s42 (StudentNet)', 'results/snapshot_synth_v1/snapshot_synth_a1.0_L4_s42.json'),
    )
    for row_label, json_path in table_rows:
        summary_row(load(json_path), row_label)

    print()
    print("Legend: ||h_L|| = median per-sample L2 norm of final hidden state; ||g_2|| = median per-sample L2 norm of BP gradient at h_2.")
    print("All norms use .norm(dim=-1), correct.")


# Run the comparison only when executed as a script, not on import.
if __name__ == '__main__':
    main()