diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 22:45:41 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 22:45:41 -0500 |
| commit | 3a520b203f4f0c75b37b2d5c34d461718729ea02 (patch) | |
| tree | 76cf5cfc7f2874bc7016414f1a586dee453f50d8 /protocol | |
| parent | 44614df2f4382e567b986bc6dbe5b3091072461e (diff) | |
Audit table extension to 3 seeds (s42/s123/s456)
3 seeds × 5 methods × 4 diagnostics = 60 measurements. Key reproducibility
findings:
- BP: trustworthy on all 3 seeds (acc 0.61-0.62, h_L ~200, g_L ~3-4e-4)
- EP: trustworthy on all 3 seeds (acc 0.29-0.36, h_L 3-8e3, g_L ~1e-4)
- DFA, SB, CB: walked back on all 3 seeds by each of diagnostics (a), (b), and (d)
Diagnostic (c) is bimodal across seeds — confirms the prior memory finding:
- DFA s42=0.047 (noise), s123=0.436 (drift), s456=-0.005 (noise)
- SB s42=0.992 (drift), s123=0.561 (drift), s456=0.035 (noise)
- CB s42=0.352 (drift), s123=0.250 (~edge), s456=0.518 (drift)
Diagnostic (c) catches different methods on different seeds, whereas diagnostics
(a)/(b)/(d) catch all 3 failing methods on all 3 seeds — robust binary detection.
Diffstat (limited to 'protocol')
| -rw-r--r-- | protocol/examples/audit_table.py | 71 |
1 file changed, 42 insertions, 29 deletions
diff --git a/protocol/examples/audit_table.py b/protocol/examples/audit_table.py index 1a75d96..da0caa9 100644 --- a/protocol/examples/audit_table.py +++ b/protocol/examples/audit_table.py @@ -100,6 +100,11 @@ FROZEN_BASELINE_ACC = { def main(): + import argparse + p = argparse.ArgumentParser() + p.add_argument("--seeds", type=int, nargs="+", default=[42]) + args = p.parse_args() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(f"Device: {device}") eval_batches = load_eval_batches(n_batches=10, batch_size=128, device=device) @@ -107,44 +112,51 @@ def main(): methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"] rows = [] reports = {} - for method in methods: - print(f"\n### {method.upper()} (seed 42)") - model = load_model(method, 42, device) - acc = evaluate(model, device) - report = diagnose( - model=model, - eval_batches=eval_batches, - headline_acc=acc, - frozen_baseline_acc=FROZEN_BASELINE_ACC.get(method), - method_name=method.upper(), - notes="4-block d=256 ResMLP, CIFAR-10, seed 42", - ) - print(report) - reports[method] = report.to_dict() - rows.append({ - "method": method, - "acc": acc, - "h_L": report.residual_norms[-1], - "g_L": report.bp_grad_norms[-1], - "stability": report.cross_batch_stability, - "frozen_acc": report.frozen_baseline_acc, - "verdict": report.verdict, - }) + for seed in args.seeds: + for method in methods: + print(f"\n### {method.upper()} (seed {seed})") + try: + model = load_model(method, seed, device) + except FileNotFoundError as e: + print(f" SKIPPED: checkpoint not found ({e})") + continue + acc = evaluate(model, device) + report = diagnose( + model=model, + eval_batches=eval_batches, + headline_acc=acc, + frozen_baseline_acc=FROZEN_BASELINE_ACC.get(method), + method_name=method.upper(), + notes=f"4-block d=256 ResMLP, CIFAR-10, seed {seed}", + ) + print(report) + reports[f"{method}_s{seed}"] = report.to_dict() + rows.append({ + "method": method, + "seed": seed, + "acc": acc, + "h_L": 
report.residual_norms[-1], + "g_L": report.bp_grad_norms[-1], + "stability": report.cross_batch_stability, + "frozen_acc": report.frozen_baseline_acc, + "verdict": report.verdict, + }) # Compact summary table - print("\n\n" + "=" * 100) - print("AUDIT SUMMARY (single seed 42, 4-block d=256 ResMLP, CIFAR-10)") - print("=" * 100) + print("\n\n" + "=" * 110) + print(f"AUDIT SUMMARY (seeds={args.seeds}, 4-block d=256 ResMLP, CIFAR-10)") + print("=" * 110) header = ( - f"{'method':<16}{'acc':>8}{'||h_L||':>14}{'||g_L||':>14}" + f"{'method':<16}{'seed':>6}{'acc':>8}{'||h_L||':>14}{'||g_L||':>14}" f"{'stab(L/2)':>12}{'frozen':>10} verdict" ) print(header) - print("-" * 100) + print("-" * 110) for r in rows: frozen = "n/a" if r["frozen_acc"] is None else f"{r['frozen_acc']:.4f}" print( f"{r['method']:<16}" + f"{r['seed']:>6}" f"{r['acc']:>8.4f}" f"{r['h_L']:>14.3e}" f"{r['g_L']:>14.3e}" @@ -152,7 +164,8 @@ def main(): f"{frozen:>10} {r['verdict']}" ) - out_path = os.path.join(OUT_DIR, "audit_table_s42.json") + seeds_tag = "_".join(f"s{s}" for s in args.seeds) + out_path = os.path.join(OUT_DIR, f"audit_table_{seeds_tag}.json") with open(out_path, "w") as f: json.dump({"reports": reports, "summary": rows}, f, indent=2) print(f"\nSaved {out_path}") |
