1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
|
"""
Compare snapshot-evolution JSON logs between the with-out_ln and no-out_ln
conditions, across seeds, and against a synthetic run. Prints one summary
row per run, for use in the P4 figure.
Usage:
python experiments/snapshot_compare_outln.py
"""
import os, sys, json, glob
import numpy as np
def load(path):
    """Load a snapshot-evolution JSON file.

    Args:
        path: Filesystem path of the JSON file to read.

    Returns:
        The decoded JSON content, or ``None`` when ``path`` does not exist.

    Raises:
        json.JSONDecodeError: If the file exists but does not contain
            valid JSON.
    """
    if not os.path.exists(path):
        return None
    # Decode explicitly as UTF-8 instead of relying on the platform's
    # locale encoding, so the same file reads identically everywhere.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
def field(d, key, layer=2):
    """Extract a per-epoch list of values for a given metric/layer.

    The log is selected by substring match on ``key``: a key containing
    ``'bp'`` reads ``d['bp_log']``; otherwise a key containing ``'dfa'``
    reads ``d['dfa_log']``. The metric name is ``key`` with every
    occurrence of ``'bp_'`` and ``'dfa_'`` removed.

    Supported metric names:
        * ``h_L_norm``  -- last entry of each record's ``hidden_norms``.
        * ``h_L2_norm`` -- entry 2 of ``hidden_norms`` (``None`` for records
          with fewer than three entries); ``layer`` is not used here.
        * ``g_l2``      -- entry ``layer`` of the per-sample gradient-norm
          list, stored under ``bp_grad_per_sample_l2_med`` or, if the first
          record lacks that key, ``bp_grad_norms_per_sample_med``.
        * ``acc``       -- each record's ``acc_eval``.
        * ``gamma_dfa`` -- each record's ``gamma_dfa`` (NaN when absent).

    Args:
        d: Decoded snapshot dict, or ``None``.
        key: Metric selector such as ``'bp_acc'`` or ``'dfa_g_l2'``.
        layer: Index used for the ``g_l2`` metric.

    Returns:
        A list with one value per log record, or ``None`` when ``d`` is
        ``None``, when ``key`` selects no log, or when the metric is unknown.
    """
    if d is None:
        return None
    if 'bp' in key:
        log = d['bp_log']
    elif 'dfa' in key:
        log = d['dfa_log']
    else:
        # The key names neither log. Bail out here instead of iterating
        # over ``None`` further down (which raised TypeError).
        return None
    metric = key.replace('bp_', '').replace('dfa_', '')
    if metric == 'h_L_norm':
        return [r['hidden_norms'][-1] for r in log]
    if metric == 'h_L2_norm':
        return [r['hidden_norms'][2] if len(r['hidden_norms']) > 2 else None for r in log]
    if metric == 'g_l2':
        if not log:
            # Nothing to probe for the key name; mirror the empty list the
            # other metrics produce for an empty log.
            return []
        # The first record decides which of the two key spellings is used.
        key_in_log = 'bp_grad_per_sample_l2_med' if 'bp_grad_per_sample_l2_med' in log[0] else 'bp_grad_norms_per_sample_med'
        return [r[key_in_log][layer] for r in log]
    if metric == 'acc':
        return [r['acc_eval'] for r in log]
    if metric == 'gamma_dfa':
        return [r.get('gamma_dfa', float('nan')) for r in log]
    return None
def _grad_key(record):
    """Return whichever per-sample gradient-norm key ``record`` carries."""
    if 'bp_grad_per_sample_l2_med' in record:
        return 'bp_grad_per_sample_l2_med'
    return 'bp_grad_norms_per_sample_med'


def summary_row(d, label):
    """Print a summary row for the comparison table.

    The row reports, for both the ``bp_log`` and ``dfa_log`` of ``d``:
    the last record's ``acc_eval``, the first and last value of the final
    ``hidden_norms`` entry together with their ratio, and the first and last
    value of entry 2 of the per-sample gradient-norm list.

    Args:
        d: Decoded snapshot dict with ``'bp_log'`` and ``'dfa_log'`` lists,
            or ``None`` when the run is unavailable.
        label: Row label, left-aligned and padded to 35 characters.

    Returns:
        None. The row is written to stdout. ``None`` data prints a
        ``MISSING`` row; an empty log prints an ``EMPTY`` row.
    """
    if d is None:
        print(f"{label:35s} MISSING")
        return
    bp = d['bp_log']
    dfa = d['dfa_log']
    if not bp or not dfa:
        # No first/last record to compare; report it rather than crash
        # with IndexError on ``[0]`` / ``[-1]``.
        print(f"{label:35s} EMPTY")
        return

    bp_h_L_init = bp[0]['hidden_norms'][-1]
    bp_h_L_final = bp[-1]['hidden_norms'][-1]
    dfa_h_L_init = dfa[0]['hidden_norms'][-1]
    dfa_h_L_final = dfa[-1]['hidden_norms'][-1]

    # Resolve the gradient key per log from its first record (the same
    # probing ``field`` does), so a DFA log that uses the other key
    # spelling does not raise KeyError.
    bp_g_key = _grad_key(bp[0])
    dfa_g_key = _grad_key(dfa[0])
    bp_g2_init = bp[0][bp_g_key][2]
    bp_g2_final = bp[-1][bp_g_key][2]
    dfa_g2_init = dfa[0][dfa_g_key][2]
    dfa_g2_final = dfa[-1][dfa_g_key][2]

    bp_acc = bp[-1]['acc_eval']
    dfa_acc = dfa[-1]['acc_eval']

    # Floor the denominator so a zero initial norm cannot divide by zero.
    bp_growth = bp_h_L_final / max(bp_h_L_init, 1e-12)
    dfa_growth = dfa_h_L_final / max(dfa_h_L_init, 1e-12)

    print(f"{label:35s} BP_acc={bp_acc:.3f} DFA_acc={dfa_acc:.3f} "
          f"BP_||h_L||: {bp_h_L_init:.1e}→{bp_h_L_final:.1e} (×{bp_growth:.1e}) "
          f"DFA_||h_L||: {dfa_h_L_init:.1e}→{dfa_h_L_final:.1e} (×{dfa_growth:.1e}) "
          f"BP_||g_2||: {bp_g2_init:.1e}→{bp_g2_final:.1e} "
          f"DFA_||g_2||: {dfa_g2_init:.1e}→{dfa_g2_final:.1e}")
def main():
    """Print the comparison banner, one summary row per run, and a legend.

    Each run is a (label, JSON path) pair; paths are relative to the
    current working directory. Runs whose file is absent are still listed
    (``load`` yields ``None`` and ``summary_row`` reports that).
    """
    rule = "=" * 130
    print(rule)
    print("SNAPSHOT EVOLUTION COMPARISON: with-out_ln vs no-out_ln vs synthetic")
    print(rule)

    # (row label, snapshot JSON path) in display order.
    table = (
        ('with-out_ln s42 (ResMLP CIFAR)',
         'results/snapshot_evolution_v2/snapshot_evolution_s42.json'),
        ('no-out_ln s42 (ResMLP CIFAR)',
         'results/snapshot_no_outln_v1/snapshot_noLN_s42.json'),
        ('no-out_ln s123 (ResMLP CIFAR)',
         'results/snapshot_no_outln_v1/snapshot_noLN_s123.json'),
        ('no-out_ln s456 (ResMLP CIFAR)',
         'results/snapshot_no_outln_v1/snapshot_noLN_s456.json'),
        ('synthetic α=1 s42 (StudentNet)',
         'results/snapshot_synth_v1/snapshot_synth_a1.0_L4_s42.json'),
    )
    for row_label, json_path in table:
        summary_row(load(json_path), row_label)

    print()
    print("Legend: ||h_L|| = median per-sample L2 norm of final hidden state; ||g_2|| = median per-sample L2 norm of BP gradient at h_2.")
    print("All norms use .norm(dim=-1), correct.")


# Run the comparison only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|