summaryrefslogtreecommitdiff
path: root/protocol/examples/penalty_lam_3seed_summary.py
blob: 0261d301f4e78ee9f4e5b83a873951ef2aad253c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
Summarize penalty 3-seed results across lambda values.

Requires:
  - results/dfa_residual_penalty/dfa_pen_lam{0.001,0.01}_s{42,123,456}.json

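For each lambda, reports per-seed final test accuracy, ‖h_L‖ and ‖g_2‖, then
their means over whichever seeds have a result file (with the population std
for accuracy); seeds without a result file are listed as not yet available.
Also checks each seed's accuracy margin over the DFA-shallow baseline against
the 2 pp threshold of diagnostic (d).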

Run:
    python -m protocol.examples.penalty_lam_3seed_summary
"""
import os
import sys
import json

import numpy as np

# The repository root is three directory levels above this module's
# absolute path.
_THIS_FILE = os.path.abspath(__file__)
_MODULE_DIR = os.path.dirname(_THIS_FILE)
_PARENT_DIR = os.path.dirname(_MODULE_DIR)
REPO_ROOT = os.path.dirname(_PARENT_DIR)

# Directory that holds the per-(lambda, seed) penalty result JSON files.
PEN_DIR = os.path.join(REPO_ROOT, "results/dfa_residual_penalty")

# Reference accuracy that main() subtracts from each run's accuracy to obtain
# the margin; the report labels it "DFA-shallow".
SHALLOW_BASELINE = 0.349


def load_one(lam, seed):
    """Load the final metrics of one penalty run.

    Args:
        lam: Lambda value exactly as it appears in the result file name
            (e.g. ``"0.001"``).
        seed: Seed exactly as it appears in the result file name (e.g. ``42``).

    Returns:
        ``None`` when ``dfa_pen_lam{lam}_s{seed}.json`` does not exist under
        ``PEN_DIR``. Otherwise a dict with ``"acc"`` (the file's
        ``final_test_acc``) and ``"h_L"`` / ``"g_2"`` (``h_L_norm`` and
        ``g_2_norm`` taken from the last entry of the file's ``log`` list).

    Raises:
        KeyError, IndexError: If the file exists but lacks the expected
            fields or its ``log`` list is empty.
    """
    path = os.path.join(PEN_DIR, f"dfa_pen_lam{lam}_s{seed}.json")
    try:
        # JSON text is UTF-8; do not depend on the platform's locale default.
        with open(path, encoding="utf-8") as f:
            d = json.load(f)
    except FileNotFoundError:
        # Opening directly (rather than exists() followed by open()) avoids a
        # race if the file disappears between the check and the read.
        return None
    final = d["log"][-1]
    return {
        "acc": d["final_test_acc"],
        "h_L": final["h_L_norm"],
        "g_2": final["g_2_norm"],
    }


def main():
    """Print the penalty-run summary for each lambda, then a reading guide.

    For lambda in ("0.001", "0.01") and seed in (42, 123, 456), the run is
    fetched with ``load_one``; seeds for which it returns ``None`` are listed
    as not yet available. If at least one seed is present for a lambda, the
    mean accuracy and norms, the accuracy margin over ``SHALLOW_BASELINE`` in
    percentage points, and the number of seeds whose margin is below 2 pp
    ("FIRE") are printed, followed by an aggregate verdict. Standard
    deviations use NumPy's default ``std()``. The function ends by printing
    a fixed block of text describing how to interpret the outcomes.
    """
    rule = "=" * 88
    print(rule)
    print("DFA + ‖f_l(h_l)‖² penalty: 3-seed summary by λ")
    print(rule)

    for lam in ("0.001", "0.01"):
        print(f"\n=== λ = {lam} ===")

        found = []
        for seed in (42, 123, 456):
            run = load_one(lam, seed)
            if run is not None:
                found.append({"seed": seed, **run})
                acc, h_L, g_2 = run["acc"], run["h_L"], run["g_2"]
                print(f"  s{seed}: acc={acc:.4f}  ‖h_L‖={h_L:.3e}  ‖g_2‖={g_2:.3e}")
            else:
                print(f"  s{seed}: NOT YET AVAILABLE")

        n_found = len(found)
        if n_found == 0:
            # Nothing to aggregate for this lambda.
            continue

        acc_arr, h_arr, g_arr = (
            np.array([entry[key] for entry in found])
            for key in ("acc", "h_L", "g_2")
        )
        # Margin over the baseline, expressed in percentage points.
        margin_arr = (acc_arr - SHALLOW_BASELINE) * 100

        acc_mean, acc_std = acc_arr.mean(), acc_arr.std()
        h_mean, g_mean = h_arr.mean(), g_arr.mean()
        print(f"  3-seed (or partial) mean: acc={acc_mean:.4f} ± {acc_std:.4f}, "
              f"‖h_L‖={h_mean:.2e}, ‖g_2‖={g_mean:.2e}")

        margin_mean, margin_std = margin_arr.mean(), margin_arr.std()
        print(f"  margin vs DFA-shallow {SHALLOW_BASELINE}: "
              f"{margin_mean:+.2f} ± {margin_std:.2f} pp")

        # (d) fires for a seed when its margin is strictly below 2 pp.
        fires = int(np.count_nonzero(margin_arr < 2.0))
        print(f"  (d) at 2pp threshold: {fires}/{n_found} seeds FIRE")

        # n_found >= 1 here, so "all fire" and "none fire" cannot coincide.
        if fires == n_found:
            verdict = "ALL FIRE — second failure mode robust to seed"
        elif fires:
            verdict = "MIXED — verdict depends on seed"
        else:
            verdict = "ALL PASS — penalty rescues to clear (d)"
        print(f"  Aggregate (d) reading at λ={lam}: {verdict}")

    closing_lines = (
        "",
        rule,
        "LAMBDA × THRESHOLD CROSS-CHECK",
        rule,
        "",
        "If λ=1e-3 3-seed mean margin exceeds 2 pp on all 3 seeds:",
        "  → my prior 'two failure modes via (d)' claim must be downgraded to",
        "    'tradeoff between penalty strength and depth utilization'",
        "",
        "If λ=1e-3 3-seed mean is ~1-2 pp (similar spread to λ=1e-2 ~1.4 pp):",
        "  → s42 +2.3 pp was a noisy outlier; the (d) 'second failure mode' story holds",
        "",
        "Either outcome is publishable. The point is to learn it before a reviewer does.",
    )
    for line in closing_lines:
        print(line)


if __name__ == "__main__":
    main()