diff options
Diffstat (limited to 'experiments/analyze_snapshot_evolution.py')
| -rw-r--r-- | experiments/analyze_snapshot_evolution.py | 60 |
1 files changed, 60 insertions, 0 deletions
"""
Read snapshot evolution JSONs (BP vs DFA over training epochs), summarize and
print comparison tables. Used for the P4 paper figure.

Usage:
    python experiments/analyze_snapshot_evolution.py <json_path>
"""
import json
import sys

import numpy as np  # noqa: F401  (not referenced below; kept from the original import list)

# Path that is read when no <json_path> argument is supplied.
DEFAULT_PATH = 'results/snapshot_evolution_v2/snapshot_evolution_s42.json'

# A log entry stores its per-sample gradient-norm medians under one of these
# keys; they are tried in this order.
_GRAD_KEYS = ('bp_grad_per_sample_l2_med', 'bp_grad_norms_per_sample_med')

# Element of the gradient list that is reported (labelled "h_2" / "g_2" in
# the printed output).
_GRAD_INDEX = 2


def _grad_l2_med(entry):
    """Return the reported gradient-norm median from a single log entry.

    The key is resolved per entry, so BP and DFA logs (or individual epochs)
    may use either of the names in ``_GRAD_KEYS``.

    Raises:
        KeyError: if the entry contains none of the known gradient keys.
    """
    for key in _GRAD_KEYS:
        if key in entry:
            return entry[key][_GRAD_INDEX]
    raise KeyError(f"log entry has none of the gradient keys {_GRAD_KEYS}")


def _sample_epochs(eps):
    """Pick the first, quartile and last items of ``eps`` without repeats.

    Returns an empty list for empty input. For short inputs several of the
    five positions coincide; ``dict.fromkeys`` drops the repeats while
    keeping the original order.
    """
    if not eps:
        return []
    n = len(eps)
    picks = [eps[i] for i in (0, n // 4, n // 2, 3 * n // 4, n - 1)]
    return list(dict.fromkeys(picks))


def summarize(log, name):
    """Print a first / middle / last summary of one training log.

    Args:
        log: sequence of per-epoch dicts. Each entry is read for ``'epoch'``,
            ``'hidden_norms'`` (last element), ``'acc_eval'`` and one of the
            gradient keys in ``_GRAD_KEYS`` (element ``_GRAD_INDEX``).
        name: label printed in the section header (e.g. ``'BP'``).

    An empty log prints only the header and a short note.
    """
    eps = [d['epoch'] for d in log]
    h_L = [d['hidden_norms'][-1] for d in log]
    g_l2 = [_grad_l2_med(d) for d in log]
    acc = [d['acc_eval'] for d in log]

    print(f"\n{name} ({len(log)} epochs):")
    if not log:
        # Nothing to index into; the lines below would raise IndexError.
        print(" (no epochs logged)")
        return

    mid = len(eps) // 2
    # Label with the epochs actually stored in the log rather than assuming
    # that the first entry is epoch 0.
    first, middle, last = eps[0], eps[mid], eps[-1]
    print(f" ||h_L||_2 median: ep{first}={h_L[0]:.3e} -> ep{middle}={h_L[mid]:.3e} -> ep{last}={h_L[-1]:.3e}")
    print(f" ||BP grad at h_2||_2 median: ep{first}={g_l2[0]:.3e} -> ep{middle}={g_l2[mid]:.3e} -> ep{last}={g_l2[-1]:.3e}")
    print(f" acc: ep{first}={acc[0]:.4f} -> ep{last}={acc[-1]:.4f}")
    # The small floors keep the ratios finite when the initial value is 0.
    print(f" ||h_L|| growth (final/initial): {h_L[-1]/max(h_L[0], 1e-12):.3e}")
    print(f" ||BP_g|| change (final/initial): {g_l2[-1]/max(g_l2[0], 1e-30):.3e}")


def main(argv=None):
    """Load a snapshot-evolution JSON file and print its summaries.

    Args:
        argv: optional list of command-line arguments (without the program
            name). Defaults to ``sys.argv[1:]``. The first argument is the
            JSON path; ``DEFAULT_PATH`` is used when none is given.

    The JSON object is read for the optional keys ``'config'``, ``'depth'``,
    ``'d_hidden'``, ``'bp_log'`` and ``'dfa_log'``. When both logs are present
    a compact table is printed for a few epochs common to both.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    path = args[0] if args else DEFAULT_PATH
    with open(path, encoding='utf-8') as f:
        d = json.load(f)
    print(f"Loaded {path}")
    print(f"config: {d.get('config', {})}")
    print(f"depth={d.get('depth')}, d_hidden={d.get('d_hidden')}")
    if 'bp_log' in d:
        summarize(d['bp_log'], 'BP')
    if 'dfa_log' in d:
        summarize(d['dfa_log'], 'DFA')

    # Print compact per-epoch comparison if both available
    if 'bp_log' in d and 'dfa_log' in d:
        bp_d = {x['epoch']: x for x in d['bp_log']}
        dfa_d = {x['epoch']: x for x in d['dfa_log']}
        eps = sorted(set(bp_d) & set(dfa_d))
        sample_eps = _sample_epochs(eps)

        print("\nPer-epoch sample (BP vs DFA):")
        if not sample_eps:
            print(" (no epochs common to both logs)")
            return
        print(f"{'epoch':>6s} {'BP_||h_L||':>12s} {'DFA_||h_L||':>12s} {'BP_||g_2||':>12s} {'DFA_||g_2||':>12s} {'BP_acc':>8s} {'DFA_acc':>8s}")
        for e in sample_eps:
            bdat = bp_d[e]
            ddat = dfa_d[e]
            bh = bdat['hidden_norms'][-1]
            dh = ddat['hidden_norms'][-1]
            # Resolve the gradient key separately for each record: the two
            # logs are not guaranteed to use the same key name.
            bg = _grad_l2_med(bdat)
            dg = _grad_l2_med(ddat)
            print(f"{e:>6d} {bh:>12.3e} {dh:>12.3e} {bg:>12.3e} {dg:>12.3e} {bdat['acc_eval']:>8.4f} {ddat['acc_eval']:>8.4f}")


if __name__ == '__main__':
    main()
