diff options
Diffstat (limited to 'paper/figures/render_fig_d512L2_panelA.py')
| -rw-r--r-- | paper/figures/render_fig_d512L2_panelA.py | 92 |
1 files changed, 92 insertions, 0 deletions
diff --git a/paper/figures/render_fig_d512L2_panelA.py b/paper/figures/render_fig_d512L2_panelA.py new file mode 100644 index 0000000..b8fabf8 --- /dev/null +++ b/paper/figures/render_fig_d512L2_panelA.py @@ -0,0 +1,92 @@ +"""Panel A style scatter for d=512 L=2 qualifying seeds, with frozen baseline line.""" +import os +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +import numpy as np + +REPO_ROOT = "/home/yurenh2/fa" + +# Per-seed data for qualifying seeds 1, 2, 5 +per_seed = { + "BP": [{"acc": 0.6061, "gamma": 1.0}, + {"acc": 0.6076, "gamma": 1.0}, + {"acc": 0.6065, "gamma": 1.0}], + "FA": [{"acc": 0.3471, "gamma": 0.4840}, + {"acc": 0.3464, "gamma": 0.4721}, + {"acc": 0.3410, "gamma": 0.4924}], + "DFA": [{"acc": 0.2978, "gamma": 0.2062}, + {"acc": 0.2968, "gamma": 0.1786}, + {"acc": 0.2963, "gamma": 0.1940}], +} +FROZEN = 0.349 + +COLORS = {"BP": "#2166ac", "FA": "#e08214", "DFA": "#b2182b"} +MARKERS = {"BP": "o", "FA": "s", "DFA": "D"} + +plt.rcParams.update({ + "font.size": 9, "axes.labelsize": 10, "axes.titlesize": 10, + "xtick.labelsize": 8, "ytick.labelsize": 8, "font.family": "serif", +}) + +fig, ax = plt.subplots(figsize=(4.0, 3.5)) + +# Axes swapped: x = Γ (cosine), y = accuracy +x_lim = (-0.22, 1.12) +y_lim = (-0.02, 0.72) + +# Four quadrants: boundaries at x=0 (cos=0) and y=0.10 (chance) +ax.fill_between([0, x_lim[1]], 0.10, y_lim[1], color="#c8e6c9", alpha=0.5, zorder=0) +ax.fill_between([x_lim[0], 0], y_lim[0], 0.10, color="#ffcdd2", alpha=0.5, zorder=0) +ax.fill_between([x_lim[0], 0], 0.10, y_lim[1], color="#f5f5f5", alpha=0.6, zorder=0) +ax.fill_between([0, x_lim[1]], y_lim[0], 0.10, color="#f5f5f5", alpha=0.6, zorder=0) + +ax.axvline(0, color="gray", lw=0.6, ls="--", zorder=1) +ax.axhline(0.10, color="gray", lw=0.6, ls="--", zorder=1) + +# Frozen baseline horizontal line (acc = FROZEN) +ax.axhline(FROZEN, color="#555", lw=1.2, ls=":", zorder=2) +ax.text(1.05, FROZEN + 0.01, f"frozen baseline ({FROZEN:.3f})", fontsize=7, + color="#555", ha="right", va="bottom") + +# Quadrant labels +ax.text(-0.11, 0.45, "cos < 0,\nacc > chance", fontsize=6.5, color="#888", + ha="center", va="center", style="italic", rotation=90) +ax.text(0.55, 0.04, "cos > 0,\nacc < chance", fontsize=6.5, color="#888", + ha="center", va="center", style="italic") +ax.text(0.55, 0.45, '"looks like\n learning"', fontsize=7, color="#388e3c", + ha="center", va="center", fontweight="bold") +ax.text(-0.11, 0.04, "neither", fontsize=6.5, color="#c62828", + ha="center", va="center", style="italic") + +# Plot all 3 seeds per method: x=gamma, y=acc +for method in ["BP", "FA", "DFA"]: + seeds = per_seed[method] + gammas = [s["gamma"] for s in seeds] + accs = [s["acc"] for s in seeds] + ax.scatter(gammas, accs, c=COLORS[method], marker=MARKERS[method], + s=60, zorder=5, edgecolors="k", linewidths=0.4, label=method) + +# Labels — annotate near the centroid of each cluster +for method, offsets in [("BP", (-8, 8)), ("FA", (8, -10)), ("DFA", (8, -10))]: + seeds = per_seed[method] + cx = np.mean([s["gamma"] for s in seeds]) + cy = np.mean([s["acc"] for s in seeds]) + ax.annotate(method, (cx, cy), + xytext=offsets, textcoords="offset points", fontsize=9, fontweight="bold", + color=COLORS[method]) + +ax.set_xlabel("Aggregate $\\Gamma$ (cosine)") +ax.set_ylabel("Test accuracy") +ax.set_xlim(*x_lim) +ax.set_ylim(*y_lim) +# No title — user will add caption externally + +ax.grid(True, which="major", color="#d0d0d0", linewidth=0.4, linestyle=":") +ax.set_axisbelow(True) + +out = os.path.join(REPO_ROOT, "paper/figures/fig_d512L2_panelA.pdf") +fig.savefig(out, bbox_inches="tight", dpi=300) +fig.savefig(out.replace(".pdf", ".png"), bbox_inches="tight", dpi=200) +print(f"Saved: {out}") |
