summaryrefslogtreecommitdiff
path: root/protocol
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:14:02 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:14:02 -0500
commit33a84534418c0459dc3988bdd53df09dcd3ab676 (patch)
tree09fe1c5f73306e7966b21819a14d6b73c740170e /protocol
parent0886902ab7decd7702c9764e8fb2c3e10a528e45 (diff)
Add §3 cross-architecture temporal evolution figure
3-column × 3-row plot. Rows: ||h_L||, ||g_L||, test accuracy. Columns: ResMLP (with LN) | ViT-Mini (cls + LN) | StudentNet (no LN). BP and DFA trajectories are overlaid, and the floor threshold is drawn on the ||g_L|| row. The figure visualizes the cross-architecture causal control: both with-LN architectures show ||g_L|| collapsing below 1e-7 (DFA hits the floor within 5 epochs), whereas the without-LN architecture keeps ||g_L|| in the healthy regime even though ||h_L|| still grows (catastrophic vs. mild).
Diffstat (limited to 'protocol')
-rw-r--r--protocol/examples/plot_temporal_cross_arch.py127
1 files changed, 127 insertions, 0 deletions
diff --git a/protocol/examples/plot_temporal_cross_arch.py b/protocol/examples/plot_temporal_cross_arch.py
new file mode 100644
index 0000000..ce83f30
--- /dev/null
+++ b/protocol/examples/plot_temporal_cross_arch.py
@@ -0,0 +1,127 @@
+"""
+Plot the cross-architecture temporal validation result as a single figure
+suitable for §3 of the paper. Three columns (one per architecture), three
+rows: ‖h_L‖, ‖g_L‖, accuracy. BP and DFA trajectories are overlaid, with the
+1e-7 floor threshold drawn as a horizontal line on the ‖g_L‖ row.
+
+Data source: per-epoch snapshot logs already saved in
+ results/snapshot_evolution_v2/snapshot_evolution_s{seed}.json (ResMLP)
+ results/snapshot_vit_v1/snapshot_vit_s{seed}.json (ViT-Mini)
+ results/snapshot_no_outln_v1/snapshot_noLN_s{seed}.json (StudentNet)
+
+This script does NOT use GPU and runs in <5 seconds.
+
+Run:
+ python -m protocol.examples.plot_temporal_cross_arch --seed 42
+"""
+import os
+import sys
+import json
+import argparse
+
+import matplotlib
+matplotlib.use("Agg") # no display needed
+import matplotlib.pyplot as plt
+
+# Absolute path three directory levels above this file: with the script at
+# protocol/examples/plot_temporal_cross_arch.py that is the repository root.
+# All snapshot inputs and the output figure are resolved relative to it.
+REPO_ROOT = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+# NOTE(review): puts the repo root first on the import path; nothing visible in
+# this file imports a project module, so this may be unnecessary -- confirm.
+sys.path.insert(0, REPO_ROOT)
+
+
+def load_snapshot(arch, seed):
+    """Load the per-epoch snapshot JSON for one architecture and seed.
+
+    Args:
+        arch: Architecture tag. ``"resmlp"`` and ``"vit"`` select their own
+            snapshot files; any other value is treated as ``"no_outln"``.
+        seed: Seed substituted into the snapshot file name.
+
+    Returns:
+        ``(d, h_key, g_key)`` where ``d`` is the parsed JSON document and
+        ``h_key`` / ``g_key`` are the keys under which each log entry stores
+        the series plotted as ``h`` and ``g`` for that architecture, or
+        ``None`` when the snapshot file does not exist.
+    """
+    # arch tag -> (path relative to REPO_ROOT, h key, g key).
+    sources = {
+        "resmlp": (
+            f"results/snapshot_evolution_v2/snapshot_evolution_s{seed}.json",
+            "hidden_norms",
+            "bp_grad_norms_per_sample_med",
+        ),
+        "vit": (
+            f"results/snapshot_vit_v1/snapshot_vit_s{seed}.json",
+            "hidden_norms_cls",
+            "bp_grad_per_sample_l2_med",
+        ),
+        "no_outln": (
+            f"results/snapshot_no_outln_v1/snapshot_noLN_s{seed}.json",
+            "hidden_norms",
+            "bp_grad_per_sample_l2_med",
+        ),
+    }
+    # Unrecognised tags keep falling back to the no-LN layout, matching the
+    # original catch-all ``else`` branch.
+    rel_path, h_key, g_key = sources.get(arch, sources["no_outln"])
+    path = os.path.join(REPO_ROOT, rel_path)
+
+    # Open directly rather than exists()-then-open(): one filesystem lookup and
+    # no window in which the file can vanish between the check and the read.
+    # JSON is read as UTF-8 explicitly so the result does not depend on locale.
+    try:
+        with open(path, encoding="utf-8") as f:
+            d = json.load(f)
+    except FileNotFoundError:
+        return None
+    return d, h_key, g_key
+
+
+def trajectory(log, h_key, g_key):
+    """Split a per-epoch log into four parallel series for plotting.
+
+    Args:
+        log: Sequence of per-epoch dicts. Each must contain ``"epoch"``,
+            ``"acc_eval"``, and non-empty sequences under ``h_key`` and
+            ``g_key``.
+        h_key: Key of the per-epoch sequence whose last element becomes the
+            ``h`` series (presumably the final layer L -- confirm against the
+            snapshot writer).
+        g_key: Same, for the ``g`` series.
+
+    Returns:
+        Tuple ``(epochs, h_L, g_L, acc)`` of four lists, one item per entry.
+    """
+    def last_values(key):
+        # Last element of the sequence stored under `key`, for every epoch.
+        return [record[key][-1] for record in log]
+
+    epoch_ids = [record["epoch"] for record in log]
+    h_last = last_values(h_key)
+    g_last = last_values(g_key)
+    eval_acc = [record["acc_eval"] for record in log]
+    return epoch_ids, h_last, g_last, eval_acc
+
+
+def _overlay_bp_dfa(ax, bp_x, bp_y, dfa_x, dfa_y):
+    """Draw the BP and DFA curves on *ax* with the figure's fixed colours."""
+    ax.plot(bp_x, bp_y, label="BP", color="C0", lw=2)
+    ax.plot(dfa_x, dfa_y, label="DFA", color="C3", lw=2)
+
+
+def main():
+    """Render the 3x3 cross-architecture figure and save it as a PNG.
+
+    One column per architecture; rows are ||h_L|| (log scale), ||g_L|| (log
+    scale, with a dashed 1e-7 floor line) and evaluation accuracy (0-1). BP
+    and DFA trajectories are overlaid in every panel. A column whose snapshot
+    file is missing is hidden and reported on stderr.
+
+    Command-line arguments:
+        --seed: integer seed selecting the snapshot files (default 42).
+
+    The figure is written to
+    ``results/protocol_audit/figure_cross_arch_temporal_s{seed}.png`` under
+    ``REPO_ROOT``; the directory is created if it does not exist yet.
+    """
+    p = argparse.ArgumentParser()
+    p.add_argument("--seed", type=int, default=42)
+    args = p.parse_args()
+
+    arches = [
+        ("resmlp", "ResMLP (with terminal LN)"),
+        ("vit", "ViT-Mini (cls + LN)"),
+        ("no_outln", "StudentNet (no terminal LN)"),
+    ]
+
+    fig, axes = plt.subplots(3, 3, figsize=(13, 9), sharex=False)
+
+    # Y-labels and legends go on the first column that has data. That is
+    # column 0 when every snapshot exists, but tying them to col == 0 would
+    # leave the figure unlabelled whenever the first snapshot is missing.
+    annotate = True
+
+    for col, (arch, label) in enumerate(arches):
+        loaded = load_snapshot(arch, args.seed)
+        if loaded is None:
+            print(
+                f"No snapshot found for {arch} (seed {args.seed}); "
+                "leaving its column blank",
+                file=sys.stderr,
+            )
+            for ax in axes[:, col]:
+                ax.set_visible(False)
+            continue
+        d, h_key, g_key = loaded
+        bp_ep, bp_h, bp_g, bp_a = trajectory(d["bp_log"], h_key, g_key)
+        dfa_ep, dfa_h, dfa_g, dfa_a = trajectory(d["dfa_log"], h_key, g_key)
+
+        # Row 0: ||h_L||
+        ax = axes[0, col]
+        _overlay_bp_dfa(ax, bp_ep, bp_h, dfa_ep, dfa_h)
+        ax.set_yscale("log")
+        ax.set_title(label, fontsize=11)
+        if annotate:
+            ax.set_ylabel(r"$\|h_L\|_2$ (log)", fontsize=10)
+            ax.legend(loc="upper left", fontsize=8)
+        ax.grid(True, which="both", alpha=0.3)
+
+        # Row 1: ||g_L|| with threshold line
+        ax = axes[1, col]
+        _overlay_bp_dfa(ax, bp_ep, bp_g, dfa_ep, dfa_g)
+        ax.axhline(1e-7, color="black", linestyle="--", lw=1, label=r"floor $10^{-7}$")
+        ax.set_yscale("log")
+        if annotate:
+            ax.set_ylabel(r"$\|g_L\|_2$ (log)", fontsize=10)
+            ax.legend(loc="upper right", fontsize=8)
+        ax.grid(True, which="both", alpha=0.3)
+
+        # Row 2: accuracy
+        ax = axes[2, col]
+        _overlay_bp_dfa(ax, bp_ep, bp_a, dfa_ep, dfa_a)
+        if annotate:
+            ax.set_ylabel("test acc", fontsize=10)
+            ax.legend(loc="lower right", fontsize=8)
+        ax.set_xlabel("epoch", fontsize=10)
+        ax.grid(True, alpha=0.3)
+        ax.set_ylim(0, 1)
+
+        annotate = False
+
+    fig.suptitle(
+        f"Cross-architecture temporal evolution of FA diagnostics (seed {args.seed})",
+        fontsize=12, y=1.0
+    )
+    fig.tight_layout()
+    out_path = os.path.join(REPO_ROOT, f"results/protocol_audit/figure_cross_arch_temporal_s{args.seed}.png")
+    # savefig does not create missing directories; make sure the target exists
+    # so a fresh checkout does not fail at the very last step.
+    os.makedirs(os.path.dirname(out_path), exist_ok=True)
+    fig.savefig(out_path, dpi=140, bbox_inches="tight")
+    print(f"Saved {out_path}")
+
+
+# Run the plotting entry point only when executed as a script or via
+# ``python -m``; importing the module does not trigger it.
+if __name__ == "__main__":
+ main()