summary refs log tree commit diff
path: root/protocol/examples/temporal_penalty_vs_vanilla.py
blob: 4c374186ccd488cff67dd4250ff6cc13bd3f93bc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
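Temporal trajectory of the protocol's (a) and (b) diagnostics, plus test
accuracy, under DFA + residual-branch penalty vs. vanilla DFA.

Uses the existing per-epoch logs (paths relative to the repository root):
  - vanilla DFA: results/snapshot_evolution_v2/snapshot_evolution_s42.json
  - penalized DFA (3 lam values): results/dfa_residual_penalty/dfa_pen_lam{0.001,0.01,0.1}_s42.json
    (only lam=0.001 and lam=0.01 are plotted; the lam=0.1 run was killed
    mid-run, and any penalty log that is missing on disk is skipped)

Writes the figure to results/protocol_audit/figure_penalty_lambda_sweep.png.
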
Run:
    python -m protocol.examples.temporal_penalty_vs_vanilla
"""
import os
import sys
import json

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Repository root: three directory levels above this file's absolute path.
# The result-log and figure paths used in this module are joined onto it.
REPO_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)


def load_vanilla_log(path=None):
    """Load the vanilla-DFA per-epoch log and keep only the plotted fields.

    Args:
        path: Optional path to the snapshot-evolution JSON file. When
            omitted, ``results/snapshot_evolution_v2/snapshot_evolution_s42.json``
            under ``REPO_ROOT`` is used.

    Returns:
        A list with one dict per entry of the file's ``"dfa_log"`` list,
        each with the keys ``"epoch"``, ``"h_L"`` (last element of the
        entry's ``hidden_norms``), ``"g_2"`` (element at index 2 of
        ``bp_grad_norms_per_sample_med``) and ``"acc"`` (``acc_eval``).

    Raises:
        FileNotFoundError: If the log file does not exist.
        KeyError, IndexError: If an entry lacks one of the fields above.
    """
    if path is None:
        path = os.path.join(
            REPO_ROOT,
            "results/snapshot_evolution_v2/snapshot_evolution_s42.json",
        )
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    return [
        {
            "epoch": entry["epoch"],
            "h_L": entry["hidden_norms"][-1],
            "g_2": entry["bp_grad_norms_per_sample_med"][2],
            "acc": entry["acc_eval"],
        }
        for entry in data["dfa_log"]
    ]


def load_penalty_log(lam, results_dir=None):
    """Load the per-epoch log of one penalized-DFA run, if it exists.

    Args:
        lam: Penalty strength exactly as it appears in the file name
            (e.g. ``"0.01"`` for ``dfa_pen_lam0.01_s42.json``).
        results_dir: Optional directory holding the penalty-run JSON files.
            When omitted, ``results/dfa_residual_penalty`` under
            ``REPO_ROOT`` is used.

    Returns:
        The value stored under the file's ``"log"`` key, or ``None`` when
        the file does not exist.

    Raises:
        KeyError: If the file exists but has no ``"log"`` key.
    """
    if results_dir is None:
        results_dir = os.path.join(REPO_ROOT, "results/dfa_residual_penalty")
    path = os.path.join(results_dir, f"dfa_pen_lam{lam}_s42.json")
    # Open first and handle absence, instead of checking existence and then
    # opening: the file could disappear between the check and the open.
    try:
        f = open(path, encoding="utf-8")
    except FileNotFoundError:
        return None
    with f:
        return json.load(f)["log"]


def _plot_metric(ax, runs, metric):
    """Plot one metric-vs-epoch line per run on *ax*.

    Args:
        ax: Matplotlib axes to draw on.
        runs: Run descriptors as built in ``main``; each is a dict with
            ``"label"``, ``"log"`` (list of per-epoch dicts, each holding an
            ``"epoch"`` key), ``"keys"`` (metric name -> key inside a log
            entry), ``"color"`` and ``"marker"``.
        metric: Which entry of ``run["keys"]`` to plot: ``"h"``, ``"g"`` or
            ``"acc"``.
    """
    for run in runs:
        log = run["log"]
        value_key = run["keys"][metric]
        ax.plot(
            [entry["epoch"] for entry in log],
            [entry[value_key] for entry in log],
            label=run["label"], color=run["color"], lw=2,
            marker=run["marker"], markersize=4, linestyle="-",
        )


def main():
    """Render the three-panel vanilla-vs-penalty figure and save it as a PNG.

    Panels, left to right: (a) the ``h_L`` norm per epoch (log scale),
    (b) the ``g_2`` norm per epoch (log scale), and test accuracy per epoch.
    Each panel also draws fixed horizontal reference lines whose values and
    legend labels are hard-coded below.

    The vanilla log is always plotted; a penalty log is plotted only when
    ``load_penalty_log`` returns a non-empty log. The figure is written to
    ``results/protocol_audit/figure_penalty_lambda_sweep.png`` under
    ``REPO_ROOT`` (the directory is created if it does not exist) and the
    output path is printed.
    """
    runs = [{
        "label": "vanilla DFA",
        "log": load_vanilla_log(),
        "keys": {"h": "h_L", "g": "g_2", "acc": "acc"},
        "color": "C3",
        "marker": None,
    }]

    # Penalty logs store the same quantities under different key names.
    # NOTE: the λ=0.1 run was killed mid-run, so it is deliberately not plotted.
    penalty_keys = {"h": "h_L_norm", "g": "g_2_norm", "acc": "acc_eval"}
    penalty_runs = (
        ("0.001", "DFA + λ=1e-3", "C2", "o"),
        ("0.01", "DFA + λ=1e-2", "C0", "s"),
    )
    for lam, label, color, marker in penalty_runs:
        log = load_penalty_log(lam)
        if log:  # skip runs whose log is missing or empty
            runs.append({
                "label": label,
                "log": log,
                "keys": penalty_keys,
                "color": color,
                "marker": marker,
            })

    fig, axes = plt.subplots(1, 3, figsize=(13, 4.0), gridspec_kw={"wspace": 0.35})

    # --- (a) ‖h_L‖ ---
    ax = axes[0]
    _plot_metric(ax, runs, "h")
    ax.axhline(200, color="gray", linestyle=":", lw=1, label="BP ~200")
    ax.set_yscale("log")
    ax.set_xlabel("epoch", fontsize=10)
    ax.set_ylabel(r"$\|h_L\|_2$ (log)", fontsize=10)
    ax.set_title("(a) residual stream", fontsize=11)
    ax.legend(loc="lower right", fontsize=8)
    ax.grid(True, which="both", alpha=0.3)

    # --- (b) ‖g_2‖ ---
    ax = axes[1]
    _plot_metric(ax, runs, "g")
    ax.axhline(1e-7, color="k", linestyle="--", lw=1.2, label="floor 1e-7")
    ax.set_yscale("log")
    ax.set_xlabel("epoch", fontsize=10)
    ax.set_ylabel(r"$\|g_2\|_2$ (log)", fontsize=10)
    ax.set_title("(b) BP grad floor", fontsize=11)
    ax.legend(loc="lower left", fontsize=8)
    ax.grid(True, which="both", alpha=0.3)

    # --- (c) accuracy ---
    ax = axes[2]
    _plot_metric(ax, runs, "acc")
    ax.axhline(0.349, color="k", linestyle="--", lw=1.0, label="DFA-shallow 0.349")
    ax.axhline(0.371, color="purple", linestyle=":", lw=1.0, label="2pp threshold 0.371")
    ax.set_xlabel("epoch", fontsize=10)
    ax.set_ylabel("test acc", fontsize=10)
    ax.set_title("acc + frozen baseline + (d) threshold", fontsize=11)
    ax.legend(loc="lower right", fontsize=7.5)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0.05, 0.55)

    fig.suptitle(
        "Penalty rescue: (a) and (b) cleanly fix; (d) verdict depends on λ choice",
        fontsize=11, y=1.05
    )
    fig.tight_layout()

    out_path = os.path.join(REPO_ROOT, "results/protocol_audit/figure_penalty_lambda_sweep.png")
    # savefig does not create missing parent directories; make sure the
    # output directory exists so a fresh checkout does not fail here.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, dpi=140, bbox_inches="tight")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"Saved {out_path}")


# Build and save the figure only when executed as a script (or via
# ``python -m``), not when the module is imported.
if __name__ == "__main__":
    main()