summaryrefslogtreecommitdiff
path: root/protocol
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 00:07:39 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 00:07:39 -0500
commit1f705408da9eb9ff0fcb6f2269dadb2ebf71a0f1 (patch)
treee759097a6c575e0c03f0e5df004c595d8caa2f57 /protocol
parent76edf529be1b8aa8813ce380d104eaa424a3dc1d (diff)
Add penalty lambda 3-seed summary script + checkpoint save in penalty test
- New script: protocol/examples/penalty_lam_3seed_summary.py
  Loads the existing penalty JSON files for lam=1e-3 and lam=1e-2 across seeds, computes the 3-seed mean margin vs. the DFA-shallow baseline, and explicitly checks the (d) verdict at the 2 pp threshold per seed and in aggregate. Reports MIXED if seeds disagree.
  Current result: lam=1e-2 has 3 seeds (margin +1.38 ± 0.05 pp, all FIRE); lam=1e-3 has 1 seed (+2.31 pp, PASSES). Awaiting s123/s456 for lam=1e-3.
- experiments/dfa_residual_penalty_test.py: now saves the model checkpoint + Bs alongside the JSON log so the post-hoc protocol can be applied without re-running. Closes the pitfall #6.5 self-disclosure (auxiliary nets must be saved for the post-hoc Gamma to be reconstructible).
Diffstat (limited to 'protocol')
-rw-r--r--protocol/examples/penalty_lam_3seed_summary.py92
1 files changed, 92 insertions, 0 deletions
diff --git a/protocol/examples/penalty_lam_3seed_summary.py b/protocol/examples/penalty_lam_3seed_summary.py
new file mode 100644
index 0000000..0261d30
--- /dev/null
+++ b/protocol/examples/penalty_lam_3seed_summary.py
@@ -0,0 +1,92 @@
+"""
+Summarize penalty 3-seed results across lambda values.
+
+Requires:
+ - results/dfa_residual_penalty/dfa_pen_lam{0.001,0.01}_s{42,123,456}.json
+
+Reports per-seed acc, h_L, g_2 + 3-seed mean and std for each lambda, and
+explicitly checks the (d) diagnostic margin against the 2pp threshold.
+
+Run:
+ python -m protocol.examples.penalty_lam_3seed_summary
+"""
+import os
+import sys
+import json
+
+import numpy as np
+
+REPO_ROOT = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+PEN_DIR = os.path.join(REPO_ROOT, "results/dfa_residual_penalty")
+SHALLOW_BASELINE = 0.349
+
+
def load_one(lam, seed, pen_dir=None):
    """Load the summary metrics of one penalty run.

    Args:
        lam: Penalty strength exactly as it is spelled in the file name
            (e.g. ``"0.001"``).  It is interpolated verbatim into the
            file name, so pass the same spelling the run was saved with.
        seed: Seed of the run (e.g. ``42``), also interpolated verbatim.
        pen_dir: Directory containing ``dfa_pen_lam{lam}_s{seed}.json``.
            Defaults to the module-level ``PEN_DIR``.

    Returns:
        ``None`` when the file does not exist.  Otherwise a dict with
        ``"acc"`` (the file's ``final_test_acc``) and ``"h_L"`` / ``"g_2"``
        (``h_L_norm`` / ``g_2_norm`` taken from the last ``log`` entry).

    Raises:
        ValueError: If the file exists but its ``log`` list is empty.
        KeyError: If one of the expected keys is missing.
    """
    base_dir = PEN_DIR if pen_dir is None else pen_dir
    path = os.path.join(base_dir, f"dfa_pen_lam{lam}_s{seed}.json")

    # EAFP: open directly instead of exists()+open(), so a file that
    # appears or disappears between the two calls cannot cause a crash.
    try:
        with open(path, encoding="utf-8") as f:
            d = json.load(f)
    except FileNotFoundError:
        return None

    log = d["log"]
    if not log:
        # A bare IndexError would not say which of the result files is bad.
        raise ValueError(f"{path}: 'log' is empty, no final entry to read")
    final = log[-1]
    return {
        "acc": d["final_test_acc"],
        "h_L": final["h_L_norm"],
        "g_2": final["g_2_norm"],
    }
+
+
def main():
    """Print the per-seed and aggregate penalty summary for each λ.

    For every λ the available seeds are loaded with ``load_one``; seeds
    whose file is missing are reported as not yet available and skipped.
    From the loaded seeds the function prints the mean/std accuracy, the
    mean ``h_L`` and ``g_2`` norms, the margin over ``SHALLOW_BASELINE`` in
    percentage points, and the (d) verdict at the 2 pp threshold.

    When fewer seeds than expected are loaded, the aggregate verdict is
    marked as partial: "ALL PASS" / "ALL FIRE" from one or two seeds says
    nothing yet about robustness to seed.

    Returns:
        None.  All output goes to stdout.
    """
    lams = ["0.001", "0.01"]
    seeds = [42, 123, 456]
    threshold_pp = 2.0  # (d) fires when the margin is below this many pp

    print("=" * 88)
    print("DFA + ‖f_l(h_l)‖² penalty: 3-seed summary by λ")
    print("=" * 88)

    for lam in lams:
        print(f"\n=== λ = {lam} ===")
        rows = []
        for seed in seeds:
            r = load_one(lam, seed)
            if r is None:
                print(f" s{seed}: NOT YET AVAILABLE")
                continue
            rows.append({"seed": seed, **r})
            print(f" s{seed}: acc={r['acc']:.4f} ‖h_L‖={r['h_L']:.3e} ‖g_2‖={r['g_2']:.3e}")
        if not rows:
            # Nothing loaded for this λ: no statistics to report.
            continue

        accs = np.array([r["acc"] for r in rows])
        h_Ls = np.array([r["h_L"] for r in rows])
        g_2s = np.array([r["g_2"] for r in rows])
        # Margin over the shallow baseline, converted to percentage points.
        margins_pp = (accs - SHALLOW_BASELINE) * 100
        print(f" 3-seed (or partial) mean: acc={accs.mean():.4f} ± {accs.std():.4f}, "
              f"‖h_L‖={h_Ls.mean():.2e}, ‖g_2‖={g_2s.mean():.2e}")
        print(f" margin vs DFA-shallow {SHALLOW_BASELINE}: "
              f"{margins_pp.mean():+.2f} ± {margins_pp.std():.2f} pp")

        # (d) verdict at 2pp threshold: a seed "fires" when its margin is
        # strictly below the threshold.
        fires = int(np.count_nonzero(margins_pp < threshold_pp))
        print(f" (d) at 2pp threshold: {fires}/{len(rows)} seeds FIRE")
        if fires == 0:
            verdict = "ALL PASS — penalty rescues to clear (d)"
        elif fires == len(rows):
            verdict = "ALL FIRE — second failure mode robust to seed"
        else:
            verdict = "MIXED — verdict depends on seed"
        if len(rows) < len(seeds):
            # Agreement among a subset of the seeds is not an aggregate
            # result; say so next to the verdict.
            verdict += f" [PARTIAL: {len(rows)}/{len(seeds)} seeds — provisional]"
        print(f" Aggregate (d) reading at λ={lam}: {verdict}")

    print()
    print("=" * 88)
    print("LAMBDA × THRESHOLD CROSS-CHECK")
    print("=" * 88)
    print()
    print("If λ=1e-3 3-seed mean margin exceeds 2 pp on all 3 seeds:")
    print(" → my prior 'two failure modes via (d)' claim must be downgraded to")
    print(" 'tradeoff between penalty strength and depth utilization'")
    print()
    print("If λ=1e-3 3-seed mean is ~1-2 pp (similar spread to λ=1e-2 ~1.4 pp):")
    print(" → s42 +2.3 pp was a noisy outlier; the (d) 'second failure mode' story holds")
    print()
    print("Either outcome is publishable. The point is to learn it before a reviewer does.")
+
+
+if __name__ == "__main__":
+ main()