""" Plot the cross-architecture temporal validation result as a single figure suitable for §3 of the paper. Three columns (one per architecture), three rows: ‖h_L‖, ‖g_L‖, accuracy. BP and DFA trajectories overlaid with the diagnostic thresholds drawn as horizontal lines. Data source: per-epoch snapshot logs already saved in results/snapshot_evolution_v2/snapshot_evolution_s{seed}.json (ResMLP) results/snapshot_vit_v1/snapshot_vit_s{seed}.json (ViT-Mini) results/snapshot_no_outln_v1/snapshot_noLN_s{seed}.json (StudentNet) This script does NOT use GPU and runs in <5 seconds. Run: python -m protocol.examples.plot_temporal_cross_arch --seed 42 """ import os import sys import json import argparse import matplotlib matplotlib.use("Agg") # no display needed import matplotlib.pyplot as plt REPO_ROOT = os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) sys.path.insert(0, REPO_ROOT) def load_snapshot(arch, seed): if arch == "resmlp": path = os.path.join(REPO_ROOT, f"results/snapshot_evolution_v2/snapshot_evolution_s{seed}.json") h_key = "hidden_norms" g_key = "bp_grad_norms_per_sample_med" elif arch == "vit": path = os.path.join(REPO_ROOT, f"results/snapshot_vit_v1/snapshot_vit_s{seed}.json") h_key = "hidden_norms_cls" g_key = "bp_grad_per_sample_l2_med" else: # no_outln path = os.path.join(REPO_ROOT, f"results/snapshot_no_outln_v1/snapshot_noLN_s{seed}.json") h_key = "hidden_norms" g_key = "bp_grad_per_sample_l2_med" if not os.path.exists(path): return None with open(path) as f: d = json.load(f) return d, h_key, g_key def trajectory(log, h_key, g_key): epochs = [e["epoch"] for e in log] h_L = [e[h_key][-1] for e in log] g_L = [e[g_key][-1] for e in log] acc = [e["acc_eval"] for e in log] return epochs, h_L, g_L, acc def main(): p = argparse.ArgumentParser() p.add_argument("--seed", type=int, default=42) args = p.parse_args() arches = [ ("resmlp", "ResMLP (with terminal LN)"), ("vit", "ViT-Mini (cls + LN)"), ("no_outln", "StudentNet (no terminal LN)"), ] fig, axes = plt.subplots(3, 3, figsize=(13, 9), sharex=False) for col, (arch, label) in enumerate(arches): loaded = load_snapshot(arch, args.seed) if loaded is None: for r in range(3): axes[r, col].set_visible(False) continue d, h_key, g_key = loaded bp_ep, bp_h, bp_g, bp_a = trajectory(d["bp_log"], h_key, g_key) dfa_ep, dfa_h, dfa_g, dfa_a = trajectory(d["dfa_log"], h_key, g_key) # Row 0: ||h_L|| ax = axes[0, col] ax.plot(bp_ep, bp_h, label="BP", color="C0", lw=2) ax.plot(dfa_ep, dfa_h, label="DFA", color="C3", lw=2) ax.set_yscale("log") ax.set_title(label, fontsize=11) if col == 0: ax.set_ylabel(r"$\|h_L\|_2$ (log)", fontsize=10) ax.legend(loc="upper left", fontsize=8) ax.grid(True, which="both", alpha=0.3) # Row 1: ||g_L|| with threshold line ax = axes[1, col] ax.plot(bp_ep, bp_g, label="BP", color="C0", lw=2) ax.plot(dfa_ep, dfa_g, label="DFA", color="C3", lw=2) ax.axhline(1e-7, color="black", linestyle="--", lw=1, label=r"floor $10^{-7}$") ax.set_yscale("log") if col == 0: ax.set_ylabel(r"$\|g_L\|_2$ (log)", fontsize=10) ax.legend(loc="upper right", fontsize=8) ax.grid(True, which="both", alpha=0.3) # Row 2: accuracy ax = axes[2, col] ax.plot(bp_ep, bp_a, label="BP", color="C0", lw=2) ax.plot(dfa_ep, dfa_a, label="DFA", color="C3", lw=2) if col == 0: ax.set_ylabel("test acc", fontsize=10) ax.set_xlabel("epoch", fontsize=10) ax.legend(loc="lower right", fontsize=8) ax.grid(True, alpha=0.3) ax.set_ylim(0, 1) fig.suptitle( f"Cross-architecture temporal evolution of FA diagnostics (seed {args.seed})", fontsize=12, y=1.0 ) fig.tight_layout() out_path = os.path.join(REPO_ROOT, f"results/protocol_audit/figure_cross_arch_temporal_s{args.seed}.png") fig.savefig(out_path, dpi=140, bbox_inches="tight") print(f"Saved {out_path}") if __name__ == "__main__": main()