summaryrefslogtreecommitdiff
path: root/protocol/examples/threshold_d_sensitivity.py
blob: 065efc7f7fa4859db8dc2693522c9483b2f668b4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
Sensitivity of diagnostic (d) frozen-blocks margin threshold.

Codex round 18 specifically called out the +1.4 pp margin on penalized
DFA as fragile under the choice of threshold. This script sweeps the
margin threshold from 0.5 pp to 10 pp and reports the verdict on each
condition (vanilla DFA at 100 and 30 epochs, penalized DFA at
lam=1e-2, and BP with and without the penalty), measured against the
frozen-random-blocks DFA-shallow baseline.

Run:
    python -m protocol.examples.threshold_d_sensitivity
"""
import os
import sys
import json

import numpy as np

# Repository root: this file lives at protocol/examples/<script>.py, so the
# root is three path components above its absolute path. (Not referenced
# by main() below.)
_THIS_FILE = os.path.abspath(__file__)
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(_THIS_FILE)))


# 3-seed mean accuracies on 4-block d=256 ResMLP CIFAR-10.
# Updated v2.32 with matched 30-epoch controls.
# Rows are (condition name, mean accuracy, std); a std of None is
# printed as "(n=1)".
_DEFAULT_CONDITIONS = (
    ("BP-trainable 100ep",     0.6147, 0.004),  # protocol_audit
    ("BP-trainable 30ep",      0.585,  0.001),  # results/bp_no_penalty_30ep
    ("BP+pen 30ep lam=1e-2",   0.532,  0.006),  # results/bp_with_penalty
    ("DFA-shallow",            0.349,  0.002),  # frozen baseline
    ("DFA-vanilla 100ep",      0.306,  0.006),  # protocol_audit
    ("DFA-vanilla 30ep",       0.301,  0.005),  # results/dfa_no_penalty_30ep
    ("DFA+pen 30ep lam=1e-2",  0.360,  0.001),  # results/dfa_pen_short
)
# Row every margin is measured against (looked up from the table rather
# than duplicated as a separate constant, so the two cannot drift apart).
_DEFAULT_BASELINE = "DFA-shallow"
# Margin thresholds (pp) swept in the verdict table.
_DEFAULT_THRESHOLDS = (0.5, 1.0, 1.5, 2.0, 3.0, 5.0, 10.0)
# Threshold the interpretation text reports as the default for (d).
_DEFAULT_THRESHOLD = 2.0


def _margin_pp(acc, baseline_acc):
    """Return how far ``acc`` sits above ``baseline_acc``, in percentage points."""
    return (acc - baseline_acc) * 100


def _verdict(margin_pp, threshold_pp):
    """Return "FIRE" (walk back) if the margin is below the threshold, else "ok"."""
    return "FIRE" if margin_pp < threshold_pp else "ok"


def _baseline_accuracy(conditions, baseline_name):
    """Return the accuracy of the row named ``baseline_name``.

    Raises:
        ValueError: if no row has that name.
    """
    for name, acc, _std in conditions:
        if name == baseline_name:
            return acc
    raise ValueError(
        f"no condition named {baseline_name!r} to use as the baseline"
    )


def _print_margin_table(conditions, shallow_acc):
    """Print each condition's accuracy, std and margin over the baseline."""
    print("=" * 80)
    print("Diagnostic (d) frozen-baseline margin threshold sensitivity")
    print("=" * 80)
    print(f"  Reference baseline (DFA-frozen-random): {shallow_acc:.4f}")
    print()
    print(f"  {'condition':<22}{'acc':>10}{'std':>10}{'margin (pp)':>14}")
    print("  " + "-" * 56)
    for name, acc, std in conditions:
        std_str = f"±{std:.4f}" if std is not None else "(n=1)"
        margin = _margin_pp(acc, shallow_acc)
        print(f"  {name:<22}{acc:>10.4f}{std_str:>10}{margin:>14.2f}")
    print()


def _print_sweep(conditions, thresholds, shallow_acc, baseline_name):
    """Print the FIRE/ok verdict of every condition at every threshold.

    The baseline row is shown as "ref": its margin over itself is 0 by
    construction, so reporting FIRE for it would be meaningless.
    """
    print("Walk-back verdict at each threshold:")
    header = "".join(f"{'>' + str(t) + 'pp':>10}" for t in thresholds)
    print(f"  {'condition':<22}{header}")
    print("  " + "-" * (22 + 10 * len(thresholds)))
    for name, acc, _std in conditions:
        if name == baseline_name:
            cells = ["ref"] * len(thresholds)
        else:
            margin = _margin_pp(acc, shallow_acc)
            cells = [_verdict(margin, t) for t in thresholds]
        print(f"  {name:<22}" + "".join(f"{cell:>10}" for cell in cells))


def _print_condition_bullet(name, margin, thresholds, default_threshold):
    """Print one interpretation bullet: where ``name`` fires vs. passes (d)."""
    firing = [t for t in thresholds if _verdict(margin, t) == "FIRE"]
    passing = [t for t in thresholds if _verdict(margin, t) == "ok"]
    if not passing:
        print(f"  - {name} margin = {margin:+.1f} pp: "
              "FIRES at ALL swept thresholds")
        if margin < 0:
            print("    (it's actively below the shallow baseline, "
                  "not just close to it)")
    elif not firing:
        print(f"  - {name} margin = {margin:+.1f} pp: "
              "passes (d) at ALL swept thresholds")
    else:
        print(f"  - {name} margin = {margin:+.1f} pp: knife-edge")
        print(f"    fires at threshold ≥ {min(firing):.1f} pp")
        print(f"    passes at threshold ≤ {max(passing):.1f} pp")
        print()
        if _verdict(margin, default_threshold) == "FIRE":
            print(f"    The default {default_threshold:.1f} pp gives a "
                  "walk-back, but a reviewer setting")
            print(f"    {max(passing):.1f} pp would say this condition "
                  "passes (d). The conclusion")
        else:
            print(f"    The default {default_threshold:.1f} pp passes (d), "
                  "but a reviewer setting")
            print(f"    {min(firing):.1f} pp would call for a walk-back. "
                  "The conclusion")
        print("    is sensitive to this choice.")
    print()


def _print_interpretation(conditions, thresholds, shallow_acc, baseline_name,
                          default_threshold):
    """Print data-driven interpretation bullets and the round-18 lesson.

    Every number is computed from ``conditions``. The earlier hand-written
    prose had gone stale against the v2.32 table: it quoted +1.4 pp for
    penalized DFA (the table gives +1.1 pp) and cited a lam=1e-3 row
    (+2.3 pp) that is not in the table. If that result is still relevant,
    add it as a row and it will be reported here automatically.
    """
    print()
    print("=" * 80)
    print("INTERPRETATION")
    print("=" * 80)
    print()
    others = [(name, _margin_pp(acc, shallow_acc))
              for name, acc, _std in conditions if name != baseline_name]
    dfa_rows = [(name, m) for name, m in others if name.startswith("DFA")]
    bp_margins = [m for name, m in others if name.startswith("BP")]

    for name, margin in dfa_rows:
        _print_condition_bullet(name, margin, thresholds, default_threshold)

    if dfa_rows and bp_margins:
        # "At most" = the best DFA variant's margin over the frozen baseline.
        best_dfa = max(m for _name, m in dfa_rows)
        best_bp = max(bp_margins)
        print(f"  Round 18 lesson: the {best_dfa:+.1f} pp finding is real "
              "but the binary verdict")
        print("  depends on a knife-edge threshold choice. The honest paper claim")
        print("  should be: 'after the penalty correction, the depth contribution is")
        print(f"  at most {best_dfa:.1f} pp above the random-blocks baseline "
              "— much smaller than")
        print(f"  BP's {best_bp:+.1f} pp gap over shallow', "
              "not 'the deep blocks are passive'.")
        print()

    print("  Compare to (a) 63x and (b) 24338x separation gaps from")
    print("  threshold_sensitivity.py — those diagnostics are robust; (d) is not.")


def main(conditions=None, thresholds=None, baseline_name=_DEFAULT_BASELINE,
         default_threshold=_DEFAULT_THRESHOLD):
    """Print the (d) margin table, the threshold sweep and an interpretation.

    Calling ``main()`` with no arguments reproduces the script's built-in
    results.

    Args:
        conditions: iterable of ``(name, mean_acc, std_or_None)`` rows;
            defaults to the hard-coded results above.
        thresholds: margin thresholds (pp) to sweep; sorted ascending.
        baseline_name: name of the row whose accuracy is the zero point
            for every margin.
        default_threshold: threshold (pp) the interpretation reports as
            the default verdict.

    Raises:
        ValueError: if no row is named ``baseline_name`` or no thresholds
            are given.
    """
    rows = list(_DEFAULT_CONDITIONS if conditions is None else conditions)
    sweep = sorted(_DEFAULT_THRESHOLDS if thresholds is None else thresholds)
    if not sweep:
        raise ValueError("at least one margin threshold is required")
    shallow_acc = _baseline_accuracy(rows, baseline_name)

    _print_margin_table(rows, shallow_acc)
    _print_sweep(rows, sweep, shallow_acc, baseline_name)
    _print_interpretation(rows, sweep, shallow_acc, baseline_name,
                          default_threshold)


# Run the report only when executed directly (e.g.
# `python -m protocol.examples.threshold_d_sensitivity`), not on import.
if __name__ == "__main__":
    main()