diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 23:14:02 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 23:14:02 -0500 |
| commit | 33a84534418c0459dc3988bdd53df09dcd3ab676 (patch) | |
| tree | 09fe1c5f73306e7966b21819a14d6b73c740170e /protocol | |
| parent | 0886902ab7decd7702c9764e8fb2c3e10a528e45 (diff) | |
Add §3 cross-architecture temporal evolution figure
3-column 3-row plot:
rows: ||h_L||, ||g_L||, test accuracy
cols: ResMLP (with LN) | ViT-Mini (cls + LN) | StudentNet (no LN)
BP and DFA trajectories are overlaid, and the floor threshold is drawn on
the ||g_L|| row. The figure visualizes the cross-architecture causal
control: both with-LN architectures show ||g_L|| collapsing below 1e-7
(DFA hits the floor within 5 epochs), whereas in the without-LN
architecture ||g_L|| stays in the healthy regime even though ||h_L|| still
grows (catastrophic vs. mild).
Diffstat (limited to 'protocol')
| -rw-r--r-- | protocol/examples/plot_temporal_cross_arch.py | 127 |
1 files changed, 127 insertions, 0 deletions
"""
Plot the cross-architecture temporal validation result as a single figure
suitable for §3 of the paper. Three columns (one per architecture), three
rows: ‖h_L‖, ‖g_L‖, accuracy. BP and DFA trajectories overlaid with the
diagnostic thresholds drawn as horizontal lines.

Data source: per-epoch snapshot logs already saved in
  results/snapshot_evolution_v2/snapshot_evolution_s{seed}.json   (ResMLP)
  results/snapshot_vit_v1/snapshot_vit_s{seed}.json               (ViT-Mini)
  results/snapshot_no_outln_v1/snapshot_noLN_s{seed}.json         (StudentNet)

This script does NOT use GPU and runs in <5 seconds.

Run:
  python -m protocol.examples.plot_temporal_cross_arch --seed 42
"""
import os
import sys
import json
import argparse

REPO_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sys.path.insert(0, REPO_ROOT)

# arch id -> (snapshot path template relative to REPO_ROOT,
#             key of the hidden-norm list, key of the gradient-norm list)
_SNAPSHOT_SOURCES = {
    "resmlp": (
        "results/snapshot_evolution_v2/snapshot_evolution_s{seed}.json",
        "hidden_norms",
        "bp_grad_norms_per_sample_med",
    ),
    "vit": (
        "results/snapshot_vit_v1/snapshot_vit_s{seed}.json",
        "hidden_norms_cls",
        "bp_grad_per_sample_l2_med",
    ),
    "no_outln": (
        "results/snapshot_no_outln_v1/snapshot_noLN_s{seed}.json",
        "hidden_norms",
        "bp_grad_per_sample_l2_med",
    ),
}


def load_snapshot(arch, seed):
    """Load the per-epoch snapshot JSON for one architecture and seed.

    Args:
        arch: Architecture id: "resmlp", "vit" or "no_outln". Any other
            value is treated as "no_outln" (kept for backward compatibility
            with the original if/elif/else chain).
        seed: Seed substituted into the snapshot file name.

    Returns:
        ``(data, h_key, g_key)`` where ``data`` is the decoded JSON document
        and ``h_key`` / ``g_key`` are the per-entry keys holding the
        hidden-norm and gradient-norm lists for this architecture, or
        ``None`` when the snapshot file does not exist.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    template, h_key, g_key = _SNAPSHOT_SOURCES.get(
        arch, _SNAPSHOT_SOURCES["no_outln"]
    )
    path = os.path.join(REPO_ROOT, template.format(seed=seed))
    try:
        # Open directly instead of exists()+open(): no race window, and the
        # explicit encoding keeps decoding independent of the locale.
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except (FileNotFoundError, NotADirectoryError):
        return None
    return data, h_key, g_key


def trajectory(log, h_key, g_key):
    """Extract the plotted series from a list of per-epoch log entries.

    Args:
        log: Sequence of dicts, each with "epoch", "acc_eval", and list
            values under ``h_key`` and ``g_key``.
        h_key: Key of the hidden-norm list in each entry.
        g_key: Key of the gradient-norm list in each entry.

    Returns:
        ``(epochs, h_L, g_L, acc)`` as four lists of equal length. ``h_L``
        and ``g_L`` take the last element of each entry's list, presumably
        the final layer L, going by the module docstring's naming.
    """
    epochs = [entry["epoch"] for entry in log]
    h_last = [entry[h_key][-1] for entry in log]
    g_last = [entry[g_key][-1] for entry in log]
    acc = [entry["acc_eval"] for entry in log]
    return epochs, h_last, g_last, acc


def _plot_bp_vs_dfa(ax, bp_x, bp_y, dfa_x, dfa_y):
    """Overlay the BP and DFA curves on *ax* with the figure's fixed colours."""
    ax.plot(bp_x, bp_y, label="BP", color="C0", lw=2)
    ax.plot(dfa_x, dfa_y, label="DFA", color="C3", lw=2)


def main():
    """Build the 3x3 cross-architecture figure and save it as a PNG.

    Reads ``--seed`` from the command line (default 42), draws one column per
    architecture whose snapshot file is available, and writes
    ``results/protocol_audit/figure_cross_arch_temporal_s{seed}.png`` under
    REPO_ROOT. Columns without a snapshot are hidden and reported on stderr.
    """
    # Imported here rather than at module level so that importing the helper
    # functions neither requires matplotlib nor switches the global backend.
    import matplotlib
    matplotlib.use("Agg")  # no display needed
    import matplotlib.pyplot as plt

    p = argparse.ArgumentParser()
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args()

    arches = [
        ("resmlp", "ResMLP (with terminal LN)"),
        ("vit", "ViT-Mini (cls + LN)"),
        ("no_outln", "StudentNet (no terminal LN)"),
    ]

    fig, axes = plt.subplots(3, 3, figsize=(13, 9), sharex=False)

    # Y-axis labels go on the first column that is actually drawn, so they
    # survive when the left-most architecture has no snapshot.
    ylabels_done = False

    for col, (arch, label) in enumerate(arches):
        loaded = load_snapshot(arch, args.seed)
        if loaded is None:
            print(
                f"[skip] no snapshot found for '{arch}' (seed {args.seed})",
                file=sys.stderr,
            )
            for row in range(3):
                axes[row, col].set_visible(False)
            continue
        d, h_key, g_key = loaded
        bp_ep, bp_h, bp_g, bp_a = trajectory(d["bp_log"], h_key, g_key)
        dfa_ep, dfa_h, dfa_g, dfa_a = trajectory(d["dfa_log"], h_key, g_key)

        with_ylabel = not ylabels_done
        ylabels_done = True

        # Row 0: ||h_L||
        ax = axes[0, col]
        _plot_bp_vs_dfa(ax, bp_ep, bp_h, dfa_ep, dfa_h)
        ax.set_yscale("log")
        ax.set_title(label, fontsize=11)
        if with_ylabel:
            ax.set_ylabel(r"$\|h_L\|_2$ (log)", fontsize=10)
        ax.legend(loc="upper left", fontsize=8)
        ax.grid(True, which="both", alpha=0.3)

        # Row 1: ||g_L|| with threshold line
        ax = axes[1, col]
        _plot_bp_vs_dfa(ax, bp_ep, bp_g, dfa_ep, dfa_g)
        ax.axhline(1e-7, color="black", linestyle="--", lw=1, label=r"floor $10^{-7}$")
        ax.set_yscale("log")
        if with_ylabel:
            ax.set_ylabel(r"$\|g_L\|_2$ (log)", fontsize=10)
        ax.legend(loc="upper right", fontsize=8)
        ax.grid(True, which="both", alpha=0.3)

        # Row 2: accuracy
        ax = axes[2, col]
        _plot_bp_vs_dfa(ax, bp_ep, bp_a, dfa_ep, dfa_a)
        if with_ylabel:
            ax.set_ylabel("test acc", fontsize=10)
        ax.set_xlabel("epoch", fontsize=10)
        ax.legend(loc="lower right", fontsize=8)
        ax.grid(True, alpha=0.3)
        ax.set_ylim(0, 1)

    fig.suptitle(
        f"Cross-architecture temporal evolution of FA diagnostics (seed {args.seed})",
        fontsize=12, y=1.0
    )
    fig.tight_layout()
    out_path = os.path.join(REPO_ROOT, f"results/protocol_audit/figure_cross_arch_temporal_s{args.seed}.png")
    # savefig does not create missing parent directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, dpi=140, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {out_path}")


if __name__ == "__main__":
    main()
