summaryrefslogtreecommitdiff
path: root/protocol
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 22:45:41 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 22:45:41 -0500
commit3a520b203f4f0c75b37b2d5c34d461718729ea02 (patch)
tree76cf5cfc7f2874bc7016414f1a586dee453f50d8 /protocol
parent44614df2f4382e567b986bc6dbe5b3091072461e (diff)
Audit table extension to 3 seeds (s42/s123/s456)
3 seeds × 5 methods × 4 diagnostics = 60 measurements. Key reproducibility findings: - BP: trustworthy on all 3 seeds (acc 0.61-0.62, h_L ~200, g_L ~3-4e-4) - EP: trustworthy on all 3 seeds (acc 0.29-0.36, h_L 3-8e3, g_L ~1e-4) - DFA, SB, CB: walked back on all 3 seeds × all 3 of (a)/(b)/(d) Diagnostic (c) is bimodal across seeds — confirms the prior memory finding: - DFA s42=0.047 (noise), s123=0.436 (drift), s456=-0.005 (noise) - SB s42=0.992 (drift), s123=0.561 (drift), s456=0.035 (noise) - CB s42=0.352 (drift), s123=0.250 (~edge), s456=0.518 (drift) (c) catches different methods on different seeds. (a)/(b)/(d) catch all 3 failing methods on all 3 seeds — robust binary detection.
Diffstat (limited to 'protocol')
-rw-r--r--protocol/examples/audit_table.py71
1 files changed, 42 insertions, 29 deletions
diff --git a/protocol/examples/audit_table.py b/protocol/examples/audit_table.py
index 1a75d96..da0caa9 100644
--- a/protocol/examples/audit_table.py
+++ b/protocol/examples/audit_table.py
@@ -100,6 +100,11 @@ FROZEN_BASELINE_ACC = {
def main():
+ import argparse
+ p = argparse.ArgumentParser()
+ p.add_argument("--seeds", type=int, nargs="+", default=[42])
+ args = p.parse_args()
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
eval_batches = load_eval_batches(n_batches=10, batch_size=128, device=device)
@@ -107,44 +112,51 @@ def main():
methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"]
rows = []
reports = {}
- for method in methods:
- print(f"\n### {method.upper()} (seed 42)")
- model = load_model(method, 42, device)
- acc = evaluate(model, device)
- report = diagnose(
- model=model,
- eval_batches=eval_batches,
- headline_acc=acc,
- frozen_baseline_acc=FROZEN_BASELINE_ACC.get(method),
- method_name=method.upper(),
- notes="4-block d=256 ResMLP, CIFAR-10, seed 42",
- )
- print(report)
- reports[method] = report.to_dict()
- rows.append({
- "method": method,
- "acc": acc,
- "h_L": report.residual_norms[-1],
- "g_L": report.bp_grad_norms[-1],
- "stability": report.cross_batch_stability,
- "frozen_acc": report.frozen_baseline_acc,
- "verdict": report.verdict,
- })
+ for seed in args.seeds:
+ for method in methods:
+ print(f"\n### {method.upper()} (seed {seed})")
+ try:
+ model = load_model(method, seed, device)
+ except FileNotFoundError as e:
+ print(f" SKIPPED: checkpoint not found ({e})")
+ continue
+ acc = evaluate(model, device)
+ report = diagnose(
+ model=model,
+ eval_batches=eval_batches,
+ headline_acc=acc,
+ frozen_baseline_acc=FROZEN_BASELINE_ACC.get(method),
+ method_name=method.upper(),
+ notes=f"4-block d=256 ResMLP, CIFAR-10, seed {seed}",
+ )
+ print(report)
+ reports[f"{method}_s{seed}"] = report.to_dict()
+ rows.append({
+ "method": method,
+ "seed": seed,
+ "acc": acc,
+ "h_L": report.residual_norms[-1],
+ "g_L": report.bp_grad_norms[-1],
+ "stability": report.cross_batch_stability,
+ "frozen_acc": report.frozen_baseline_acc,
+ "verdict": report.verdict,
+ })
# Compact summary table
- print("\n\n" + "=" * 100)
- print("AUDIT SUMMARY (single seed 42, 4-block d=256 ResMLP, CIFAR-10)")
- print("=" * 100)
+ print("\n\n" + "=" * 110)
+ print(f"AUDIT SUMMARY (seeds={args.seeds}, 4-block d=256 ResMLP, CIFAR-10)")
+ print("=" * 110)
header = (
- f"{'method':<16}{'acc':>8}{'||h_L||':>14}{'||g_L||':>14}"
+ f"{'method':<16}{'seed':>6}{'acc':>8}{'||h_L||':>14}{'||g_L||':>14}"
f"{'stab(L/2)':>12}{'frozen':>10} verdict"
)
print(header)
- print("-" * 100)
+ print("-" * 110)
for r in rows:
frozen = "n/a" if r["frozen_acc"] is None else f"{r['frozen_acc']:.4f}"
print(
f"{r['method']:<16}"
+ f"{r['seed']:>6}"
f"{r['acc']:>8.4f}"
f"{r['h_L']:>14.3e}"
f"{r['g_L']:>14.3e}"
@@ -152,7 +164,8 @@ def main():
f"{frozen:>10} {r['verdict']}"
)
- out_path = os.path.join(OUT_DIR, "audit_table_s42.json")
+ seeds_tag = "_".join(f"s{s}" for s in args.seeds)
+ out_path = os.path.join(OUT_DIR, f"audit_table_{seeds_tag}.json")
with open(out_path, "w") as f:
json.dump({"reports": reports, "summary": rows}, f, indent=2)
print(f"\nSaved {out_path}")