""" Read snapshot evolution JSONs (BP vs DFA over training epochs), summarize and print comparison tables. Used for the P4 paper figure. Usage: python experiments/analyze_snapshot_evolution.py """ import sys, json import numpy as np def summarize(log, name): eps = [d['epoch'] for d in log] h_L = [d['hidden_norms'][-1] for d in log] g_l2 = [d['bp_grad_per_sample_l2_med'][2] if 'bp_grad_per_sample_l2_med' in d else d['bp_grad_norms_per_sample_med'][2] for d in log] acc = [d['acc_eval'] for d in log] print(f"\n{name} ({len(log)} epochs):") print(f" ||h_L||_2 median: ep0={h_L[0]:.3e} -> ep{eps[len(eps)//2]}={h_L[len(eps)//2]:.3e} -> ep{eps[-1]}={h_L[-1]:.3e}") print(f" ||BP grad at h_2||_2 median: ep0={g_l2[0]:.3e} -> ep{eps[len(eps)//2]}={g_l2[len(eps)//2]:.3e} -> ep{eps[-1]}={g_l2[-1]:.3e}") print(f" acc: ep0={acc[0]:.4f} -> ep{eps[-1]}={acc[-1]:.4f}") print(f" ||h_L|| growth (final/initial): {h_L[-1]/max(h_L[0], 1e-12):.3e}") print(f" ||BP_g|| change (final/initial): {g_l2[-1]/max(g_l2[0], 1e-30):.3e}") def main(): path = sys.argv[1] if len(sys.argv) > 1 else 'results/snapshot_evolution_v2/snapshot_evolution_s42.json' with open(path) as f: d = json.load(f) print(f"Loaded {path}") print(f"config: {d.get('config', {})}") print(f"depth={d.get('depth')}, d_hidden={d.get('d_hidden')}") if 'bp_log' in d: summarize(d['bp_log'], 'BP') if 'dfa_log' in d: summarize(d['dfa_log'], 'DFA') # Print compact per-epoch comparison if both available if 'bp_log' in d and 'dfa_log' in d: bp = d['bp_log'] dfa = d['dfa_log'] eps = sorted(set([x['epoch'] for x in bp]) & set([x['epoch'] for x in dfa])) sample_eps = [eps[i] for i in [0, len(eps)//4, len(eps)//2, 3*len(eps)//4, -1]] print(f"\nPer-epoch sample (BP vs DFA):") print(f"{'epoch':>6s} {'BP_||h_L||':>12s} {'DFA_||h_L||':>12s} {'BP_||g_2||':>12s} {'DFA_||g_2||':>12s} {'BP_acc':>8s} {'DFA_acc':>8s}") bp_d = {x['epoch']: x for x in bp} dfa_d = {x['epoch']: x for x in dfa} for e in sample_eps: bdat = bp_d[e] ddat = dfa_d[e] bh = bdat['hidden_norms'][-1] dh = ddat['hidden_norms'][-1] bg_key = 'bp_grad_per_sample_l2_med' if 'bp_grad_per_sample_l2_med' in bdat else 'bp_grad_norms_per_sample_med' bg = bdat[bg_key][2] dg = ddat[bg_key][2] print(f"{e:>6d} {bh:>12.3e} {dh:>12.3e} {bg:>12.3e} {dg:>12.3e} {bdat['acc_eval']:>8.4f} {ddat['acc_eval']:>8.4f}") if __name__ == '__main__': main()