summaryrefslogtreecommitdiff
path: root/protocol/examples/plot_temporal_cross_arch.py
blob: ce83f305f57257934ce058f4ec54fec78ea38676 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
Plot the cross-architecture temporal validation result as a single figure
suitable for §3 of the paper. Three columns (one per architecture), three
rows: ‖h_L‖, ‖g_L‖, accuracy. BP and DFA trajectories are overlaid in every
panel; the ‖g_L‖ row also shows the 10^-7 gradient floor as a dashed
horizontal line.

Data source: per-epoch snapshot logs already saved in
  results/snapshot_evolution_v2/snapshot_evolution_s{seed}.json (ResMLP)
  results/snapshot_vit_v1/snapshot_vit_s{seed}.json                 (ViT-Mini)
  results/snapshot_no_outln_v1/snapshot_noLN_s{seed}.json           (StudentNet)

This script does NOT use the GPU and runs in <5 seconds. It writes
  results/protocol_audit/figure_cross_arch_temporal_s{seed}.png

Run:
    python -m protocol.examples.plot_temporal_cross_arch --seed 42
"""
import os
import sys
import json
import argparse

import matplotlib
matplotlib.use("Agg")  # no display needed
import matplotlib.pyplot as plt

# Repository root: three directory levels above this file
# (<root>/protocol/examples/plot_temporal_cross_arch.py -> <root>).
REPO_ROOT = os.path.normpath(
    os.path.join(os.path.abspath(__file__), os.pardir, os.pardir, os.pardir)
)
# Prepend the repository root to the import search path.
sys.path.insert(0, REPO_ROOT)


def load_snapshot(arch, seed, root=None):
    """Load the per-epoch snapshot log for one architecture and seed.

    Args:
        arch: one of ``"resmlp"``, ``"vit"`` or ``"no_outln"``.
        seed: seed number embedded in the snapshot filename.
        root: directory the ``results/...`` paths are resolved against;
            defaults to the module-level ``REPO_ROOT``.

    Returns:
        ``(d, h_key, g_key)``: ``d`` is the parsed JSON document, and
        ``h_key`` / ``g_key`` name the per-epoch hidden-norm and BP
        gradient-norm fields for this architecture. Returns ``None`` when the
        snapshot file does not exist.

    Raises:
        ValueError: if ``arch`` is not one of the known architectures.
    """
    # arch -> (path template relative to root, hidden-norm key, BP grad-norm key)
    layouts = {
        "resmlp": (
            "results/snapshot_evolution_v2/snapshot_evolution_s{seed}.json",
            "hidden_norms",
            "bp_grad_norms_per_sample_med",
        ),
        "vit": (
            "results/snapshot_vit_v1/snapshot_vit_s{seed}.json",
            "hidden_norms_cls",
            "bp_grad_per_sample_l2_med",
        ),
        "no_outln": (
            "results/snapshot_no_outln_v1/snapshot_noLN_s{seed}.json",
            "hidden_norms",
            "bp_grad_per_sample_l2_med",
        ),
    }
    try:
        template, h_key, g_key = layouts[arch]
    except KeyError:
        # Previously any unknown arch silently loaded the no_outln files.
        raise ValueError(
            f"unknown arch {arch!r}; expected one of {sorted(layouts)}"
        ) from None

    base = REPO_ROOT if root is None else root
    path = os.path.join(base, template.format(seed=seed))
    # EAFP: avoids the exists()/open() race; a missing file means "no data".
    try:
        with open(path, encoding="utf-8") as f:
            d = json.load(f)
    except FileNotFoundError:
        return None
    return d, h_key, g_key


def trajectory(log, h_key, g_key):
    """Split a per-epoch log into plot-ready series.

    Args:
        log: iterable of per-epoch records (dicts). Each record must have an
            ``"epoch"`` value, an ``"acc_eval"`` value, and sequences under
            ``h_key`` and ``g_key`` whose last element is taken (presumably
            the final layer L — confirm against the snapshot writer).
        h_key: key of the hidden-norm sequence.
        g_key: key of the gradient-norm sequence.

    Returns:
        ``(epochs, h_L, g_L, acc)`` as four lists of equal length.
    """
    # Materialize once: four passes over a one-shot iterator would leave every
    # series after the first empty.
    records = list(log)
    epochs = [rec["epoch"] for rec in records]
    h_L = [rec[h_key][-1] for rec in records]
    g_L = [rec[g_key][-1] for rec in records]
    acc = [rec["acc_eval"] for rec in records]
    return epochs, h_L, g_L, acc


def _plot_column(axes, col, label, d, h_key, g_key):
    """Draw the ‖h_L‖, ‖g_L‖ and accuracy panels for one architecture column.

    Args:
        axes: the 3x3 array of Axes returned by ``plt.subplots(3, 3)``.
        col: column index to draw into.
        label: architecture name used as the column title.
        d: parsed snapshot dict from ``load_snapshot``; its ``"bp_log"`` and
            ``"dfa_log"`` entries are passed to ``trajectory``.
        h_key, g_key: per-epoch field names for the hidden-norm and
            gradient-norm series.
    """
    # Extract both runs up front so a malformed log fails before any panel
    # of this column is drawn.
    runs = (
        ("BP", "C0", trajectory(d["bp_log"], h_key, g_key)),
        ("DFA", "C3", trajectory(d["dfa_log"], h_key, g_key)),
    )

    def draw_runs(ax, field):
        # field indexes trajectory()'s tuple: 1 = h_L, 2 = g_L, 3 = accuracy.
        for name, color, traj in runs:
            ax.plot(traj[0], traj[field], label=name, color=color, lw=2)

    # Row 0: ||h_L|| on a log axis; the column title sits on the top panel.
    ax = axes[0, col]
    draw_runs(ax, 1)
    ax.set_yscale("log")
    ax.set_title(label, fontsize=11)
    if col == 0:
        ax.set_ylabel(r"$\|h_L\|_2$ (log)", fontsize=10)
    ax.legend(loc="upper left", fontsize=8)
    ax.grid(True, which="both", alpha=0.3)

    # Row 1: ||g_L|| on a log axis with the 1e-7 gradient-floor reference line.
    ax = axes[1, col]
    draw_runs(ax, 2)
    ax.axhline(1e-7, color="black", linestyle="--", lw=1, label=r"floor $10^{-7}$")
    ax.set_yscale("log")
    if col == 0:
        ax.set_ylabel(r"$\|g_L\|_2$ (log)", fontsize=10)
    ax.legend(loc="upper right", fontsize=8)
    ax.grid(True, which="both", alpha=0.3)

    # Row 2: accuracy on a fixed [0, 1] axis (assumes acc_eval is a fraction,
    # not a percentage — TODO confirm against the snapshot writer).
    ax = axes[2, col]
    draw_runs(ax, 3)
    if col == 0:
        ax.set_ylabel("test acc", fontsize=10)
    ax.set_xlabel("epoch", fontsize=10)
    ax.legend(loc="lower right", fontsize=8)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 1)


def main(argv=None):
    """Render the 3x3 BP-vs-DFA figure for one seed and save it as a PNG.

    Columns whose snapshot file is missing are hidden, and a warning is
    printed to stderr. The figure is written to
    ``<REPO_ROOT>/results/protocol_audit/figure_cross_arch_temporal_s{seed}.png``.
    The output directory is created if needed.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]`` (for
            programmatic use); ``None`` keeps the command-line behavior.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args(argv)

    arches = [
        ("resmlp", "ResMLP (with terminal LN)"),
        ("vit", "ViT-Mini (cls + LN)"),
        ("no_outln", "StudentNet (no terminal LN)"),
    ]

    fig, axes = plt.subplots(3, 3, figsize=(13, 9), sharex=False)

    for col, (arch, label) in enumerate(arches):
        loaded = load_snapshot(arch, args.seed)
        if loaded is None:
            # Hide the whole column, but tell the user why it is blank.
            print(
                f"warning: no snapshot for {arch!r} (seed {args.seed}); column hidden",
                file=sys.stderr,
            )
            for r in range(3):
                axes[r, col].set_visible(False)
            continue
        d, h_key, g_key = loaded
        _plot_column(axes, col, label, d, h_key, g_key)

    fig.suptitle(
        f"Cross-architecture temporal evolution of FA diagnostics (seed {args.seed})",
        fontsize=12, y=1.0
    )
    fig.tight_layout()
    out_path = os.path.join(REPO_ROOT, f"results/protocol_audit/figure_cross_arch_temporal_s{args.seed}.png")
    # savefig does not create parent directories; a fresh checkout may lack them.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, dpi=140, bbox_inches="tight")
    # Release the figure explicitly so repeated calls (e.g. via main(argv)) don't leak.
    plt.close(fig)
    print(f"Saved {out_path}")


# Run only when executed as a script (e.g.
# `python -m protocol.examples.plot_temporal_cross_arch --seed 42`), not on import.
if __name__ == "__main__":
    main()