summaryrefslogtreecommitdiff
path: root/protocol/examples/plot_penalty_rescue.py
blob: 37b0fa9ad6ae4a445d11e6f5ec52bcf325997fed (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
Plot the §4 penalty-rescue figure: per-epoch trajectories of the protocol's
(a) and (b) diagnostics on vanilla DFA vs penalized DFA, plus accuracy.

Shows visually that the penalty rescues both diagnostics back into the
healthy regime, while the headline accuracy gain is small (the second
failure mode persists).

Data sources:
  - vanilla DFA trajectory: results/snapshot_evolution_v2/snapshot_evolution_s42.json
  - penalized DFA (lam=1e-2): results/dfa_residual_penalty/dfa_pen_lam0.01_s42.json
  - DFA-shallow baseline 3-seed mean: 0.349 (hard-coded; drawn as a horizontal reference line)
  - BP-trainable 3-seed mean: 0.609 (hard-coded; drawn as a horizontal reference line)

Run:
    python -m protocol.examples.plot_penalty_rescue
"""
import os
import sys
import json

import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# This script lives at <repo>/protocol/examples/, so the repository root is
# three directory levels above the script's absolute path.
_SCRIPT_PATH = os.path.abspath(__file__)
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(_SCRIPT_PATH)))


def load_vanilla_log(path=None):
    """Load the vanilla-DFA per-epoch trajectory and reduce it to plotted fields.

    Args:
        path: JSON file to read. Defaults to
            ``results/snapshot_evolution_v2/snapshot_evolution_s42.json``
            under ``REPO_ROOT`` (the seed-42 run used by the figure).

    Returns:
        One dict per entry of the file's ``"dfa_log"`` list, with keys:
        ``"epoch"``; ``"h_L"`` (last element of ``hidden_norms``);
        ``"g_2"`` (element 2 of ``bp_grad_norms_per_sample_med``);
        ``"acc"`` (``acc_eval``).

    Raises:
        FileNotFoundError: if the JSON file does not exist.
        KeyError / IndexError: if an entry lacks the expected fields.
    """
    if path is None:
        path = os.path.join(
            REPO_ROOT, "results", "snapshot_evolution_v2", "snapshot_evolution_s42.json"
        )
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        d = json.load(f)
    return [
        {
            "epoch": e["epoch"],
            "h_L": e["hidden_norms"][-1],
            "g_2": e["bp_grad_norms_per_sample_med"][2],
            "acc": e["acc_eval"],
        }
        for e in d["dfa_log"]
    ]


def load_penalty_log(path=None):
    """Load the penalized-DFA run log and its final test accuracy.

    Args:
        path: JSON file to read. Defaults to
            ``results/dfa_residual_penalty/dfa_pen_lam0.01_s42.json`` under
            ``REPO_ROOT`` (lambda = 1e-2, seed 42).

    Returns:
        ``(log, final_test_acc)``: the file's ``"log"`` list, returned as-is,
        and its ``"final_test_acc"`` value.

    Raises:
        FileNotFoundError: if the JSON file does not exist.
        KeyError: if either top-level key is missing.
    """
    if path is None:
        path = os.path.join(
            REPO_ROOT, "results", "dfa_residual_penalty", "dfa_pen_lam0.01_s42.json"
        )
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        d = json.load(f)
    return d["log"], d["final_test_acc"]


def _plot_pair(ax, vanilla, penalty, vanilla_key, penalty_key, penalty_label):
    """Draw one metric's per-epoch curves: vanilla DFA (C3) and penalized DFA (C2).

    ``vanilla`` is the list returned by ``load_vanilla_log`` and ``penalty`` is
    the raw log returned by ``load_penalty_log``. The two use different key
    names for the same quantity, so the caller passes the key for each.
    """
    ax.plot([e["epoch"] for e in vanilla], [e[vanilla_key] for e in vanilla],
            label="vanilla DFA", color="C3", lw=2)
    ax.plot([e["epoch"] for e in penalty], [e[penalty_key] for e in penalty],
            label=penalty_label, color="C2", lw=2, marker="o", markersize=4)


def main():
    """Render the three-panel penalty-rescue figure and save it as a PNG.

    Each panel compares vanilla DFA with penalized DFA over epochs:
      (a) the last hidden norm ``h_L`` / ``h_L_norm`` on a log scale, with a
          BP reference line at 200;
      (b) the BP gradient norm ``g_2`` / ``g_2_norm`` on a log scale, with a
          1e-7 floor line;
      third panel: evaluation accuracy, with the DFA-shallow (0.349) and
          BP-trainable (0.609) baselines drawn as horizontal lines.

    Writes ``results/protocol_audit/figure_penalty_rescue_s42.png`` under
    ``REPO_ROOT``, creating the directory if needed, and prints the path.
    """
    vanilla = load_vanilla_log()
    # The penalized run's final test accuracy is not drawn in this figure.
    penalty, _penalty_final_acc = load_penalty_log()

    fig, axes = plt.subplots(1, 3, figsize=(13, 4.0), gridspec_kw={"wspace": 0.35})
    penalty_label = r"DFA + $\lambda \|f_l\|^2$"

    # --- (1) ‖h_L‖ ---
    ax = axes[0]
    _plot_pair(ax, vanilla, penalty, "h_L", "h_L_norm",
               penalty_label + r" ($\lambda=10^{-2}$)")
    ax.axhline(200, color="C0", linestyle=":", lw=1, label=r"BP $\|h_L\| \approx 200$")
    ax.set_yscale("log")
    ax.set_xlabel("epoch", fontsize=10)
    ax.set_ylabel(r"$\|h_L\|_2$ (log)", fontsize=10)
    ax.set_title("(a) residual stream", fontsize=11)
    ax.legend(loc="lower right", fontsize=7.5)
    ax.grid(True, which="both", alpha=0.3)

    # --- (2) ‖g_2‖ ---
    ax = axes[1]
    _plot_pair(ax, vanilla, penalty, "g_2", "g_2_norm", penalty_label)
    ax.axhline(1e-7, color="k", linestyle="--", lw=1.2, label=r"floor $10^{-7}$")
    ax.set_yscale("log")
    ax.set_xlabel("epoch", fontsize=10)
    # NOTE(review): the data is g_2 (index 2 of the per-layer grad norms) but
    # the axis label says g_L. Confirm index 2 is the intended layer.
    ax.set_ylabel(r"$\|g_L\|_2$ (log)", fontsize=10)
    ax.set_title("(b) BP grad at hidden layer", fontsize=11)
    ax.legend(loc="lower left", fontsize=7.5)
    ax.grid(True, which="both", alpha=0.3)

    # --- (3) accuracy ---
    ax = axes[2]
    _plot_pair(ax, vanilla, penalty, "acc", "acc_eval", penalty_label)
    ax.axhline(0.349, color="k", linestyle="--", lw=1.2, label="DFA-shallow 0.349")
    ax.axhline(0.609, color="C0", linestyle=":", lw=1, label="BP-trainable 0.609")
    ax.set_xlabel("epoch", fontsize=10)
    ax.set_ylabel("test acc", fontsize=10)
    # NOTE(review): "(d)" presumably follows the protocol's diagnostic
    # lettering rather than the panel order; confirm.
    ax.set_title("(d) headline accuracy", fontsize=11)
    ax.legend(loc="lower right", fontsize=7.5)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0.05, 0.7)

    fig.suptitle(
        r"Penalty rescue (4-block d=256 ResMLP, seed 42): the $\|f_l\|^2$ penalty fixes (a) and (b),"
        "\nbut the deep blocks still fail to clear the frozen baseline (the second failure mode)",
        fontsize=11, y=1.05
    )
    fig.tight_layout()
    out_path = os.path.join(REPO_ROOT, "results/protocol_audit/figure_penalty_rescue_s42.png")
    # savefig does not create missing directories, so create the output folder on a fresh checkout.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, dpi=140, bbox_inches="tight")
    # Release the figure's memory now that it has been written.
    plt.close(fig)
    print(f"Saved {out_path}")


if __name__ == "__main__":
    # Script entry point, e.g. `python -m protocol.examples.plot_penalty_rescue`.
    main()