summaryrefslogtreecommitdiff
path: root/experiments/analyze_snapshot_evolution.py
blob: 8b9f8afb0eff98bc826061d79405ae649a0e13e3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
Read snapshot evolution JSONs (BP vs DFA over training epochs), summarize and
print comparison tables. Used for the P4 paper figure.

Usage:
    python experiments/analyze_snapshot_evolution.py <json_path>
"""
import sys, json
import numpy as np


def summarize(log, name, layer=2):
    """Print a first/middle/last-epoch summary of one training run's snapshot log.

    Parameters
    ----------
    log : list of dict
        Per-epoch snapshot records, summarized in the order given (not
        re-sorted). Each record is read for ``'epoch'``, ``'hidden_norms'``
        (its last element is reported as ``h_L``), ``'acc_eval'``, and either
        ``'bp_grad_per_sample_l2_med'`` or the legacy key
        ``'bp_grad_norms_per_sample_med'``.
    name : str
        Label printed in the header line (e.g. ``'BP'`` or ``'DFA'``).
    layer : int, optional
        Index into the per-layer BP-gradient list to report. The default of 2
        matches the ``h_2`` column used elsewhere in this script.

    An empty ``log`` prints only the header line instead of raising.
    """
    print(f"\n{name} ({len(log)} epochs):")
    if not log:
        return

    def grad_at_layer(entry):
        # Prefer the newer key name; fall back to the legacy one.
        key = ('bp_grad_per_sample_l2_med' if 'bp_grad_per_sample_l2_med' in entry
               else 'bp_grad_norms_per_sample_med')
        return entry[key][layer]

    eps = [d['epoch'] for d in log]
    h_L = [d['hidden_norms'][-1] for d in log]
    g_l = [grad_at_layer(d) for d in log]
    acc = [d['acc_eval'] for d in log]
    first, mid, last = 0, len(log) // 2, -1

    def trajectory(vals):
        # Label each point with its real epoch number; the log is not
        # guaranteed to start at epoch 0.
        return ' -> '.join(f"ep{eps[i]}={vals[i]:.3e}" for i in (first, mid, last))

    print(f"  ||h_L||_2 median: {trajectory(h_L)}")
    print(f"  ||BP grad at h_{layer}||_2 median: {trajectory(g_l)}")
    print(f"  acc: ep{eps[first]}={acc[first]:.4f} -> ep{eps[last]}={acc[last]:.4f}")
    # Floors on the denominators avoid division by zero for a zero initial value.
    print(f"  ||h_L|| growth (final/initial): {h_L[-1] / max(h_L[0], 1e-12):.3e}")
    print(f"  ||BP_g|| change (final/initial): {g_l[-1] / max(g_l[0], 1e-30):.3e}")


def main():
    """Load a snapshot-evolution JSON and print BP/DFA summaries.

    The JSON path is ``sys.argv[1]`` if given, otherwise a default results
    path. Top-level keys read: ``'config'``, ``'depth'``, ``'d_hidden'``, and
    optionally ``'bp_log'`` / ``'dfa_log'`` (lists of per-epoch records).

    Each available log is summarized. When both logs are present, a compact
    table is printed for up to five epochs sampled at roughly the 0/25/50/75/100%
    positions of the epochs common to both logs.
    """
    path = (sys.argv[1] if len(sys.argv) > 1
            else 'results/snapshot_evolution_v2/snapshot_evolution_s42.json')
    with open(path, encoding='utf-8') as f:
        d = json.load(f)
    print(f"Loaded {path}")
    print(f"config: {d.get('config', {})}")
    print(f"depth={d.get('depth')}, d_hidden={d.get('d_hidden')}")
    if 'bp_log' in d:
        summarize(d['bp_log'], 'BP')
    if 'dfa_log' in d:
        summarize(d['dfa_log'], 'DFA')

    # Print compact per-epoch comparison only if both logs are available.
    if 'bp_log' not in d or 'dfa_log' not in d:
        return

    def grad_at_h2(entry):
        # Resolve the gradient key per record: nothing guarantees that the BP
        # and DFA logs use the same (newer vs legacy) key name.
        key = ('bp_grad_per_sample_l2_med' if 'bp_grad_per_sample_l2_med' in entry
               else 'bp_grad_norms_per_sample_med')
        return entry[key][2]

    bp_d = {x['epoch']: x for x in d['bp_log']}
    dfa_d = {x['epoch']: x for x in d['dfa_log']}
    eps = sorted(bp_d.keys() & dfa_d.keys())
    print("\nPer-epoch sample (BP vs DFA):")
    if not eps:
        print("  (no epochs common to both logs)")
        return
    n = len(eps)
    # dict.fromkeys de-duplicates while keeping order; with few epochs the
    # quartile indices coincide and would otherwise repeat rows.
    sample_eps = list(dict.fromkeys(eps[i] for i in (0, n // 4, n // 2, 3 * n // 4, n - 1)))
    print(f"{'epoch':>6s} {'BP_||h_L||':>12s} {'DFA_||h_L||':>12s} {'BP_||g_2||':>12s} {'DFA_||g_2||':>12s} {'BP_acc':>8s} {'DFA_acc':>8s}")
    for e in sample_eps:
        bdat, ddat = bp_d[e], dfa_d[e]
        bh = bdat['hidden_norms'][-1]
        dh = ddat['hidden_norms'][-1]
        bg = grad_at_h2(bdat)
        dg = grad_at_h2(ddat)
        print(f"{e:>6d} {bh:>12.3e} {dh:>12.3e} {bg:>12.3e} {dg:>12.3e} {bdat['acc_eval']:>8.4f} {ddat['acc_eval']:>8.4f}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()