summaryrefslogtreecommitdiff
path: root/paper/figures/render_fig_d512L2_panelA.py
diff options
context:
space:
mode:
Diffstat (limited to 'paper/figures/render_fig_d512L2_panelA.py')
-rw-r--r--paper/figures/render_fig_d512L2_panelA.py92
1 files changed, 92 insertions, 0 deletions
diff --git a/paper/figures/render_fig_d512L2_panelA.py b/paper/figures/render_fig_d512L2_panelA.py
new file mode 100644
index 0000000..b8fabf8
--- /dev/null
+++ b/paper/figures/render_fig_d512L2_panelA.py
@@ -0,0 +1,92 @@
+"""Panel A style scatter for d=512 L=2 qualifying seeds, with frozen baseline line."""
+import os
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import matplotlib.patches as mpatches
+import numpy as np
+
+REPO_ROOT = "/home/yurenh2/fa"
+
+# Per-seed data for qualifying seeds 1, 2, 5
+per_seed = {
+ "BP": [{"acc": 0.6061, "gamma": 1.0},
+ {"acc": 0.6076, "gamma": 1.0},
+ {"acc": 0.6065, "gamma": 1.0}],
+ "FA": [{"acc": 0.3471, "gamma": 0.4840},
+ {"acc": 0.3464, "gamma": 0.4721},
+ {"acc": 0.3410, "gamma": 0.4924}],
+ "DFA": [{"acc": 0.2978, "gamma": 0.2062},
+ {"acc": 0.2968, "gamma": 0.1786},
+ {"acc": 0.2963, "gamma": 0.1940}],
+}
+FROZEN = 0.349
+
+COLORS = {"BP": "#2166ac", "FA": "#e08214", "DFA": "#b2182b"}
+MARKERS = {"BP": "o", "FA": "s", "DFA": "D"}
+
+plt.rcParams.update({
+ "font.size": 9, "axes.labelsize": 10, "axes.titlesize": 10,
+ "xtick.labelsize": 8, "ytick.labelsize": 8, "font.family": "serif",
+})
+
+fig, ax = plt.subplots(figsize=(4.0, 3.5))
+
+# Axes swapped: x = Γ (cosine), y = accuracy
+x_lim = (-0.22, 1.12)
+y_lim = (-0.02, 0.72)
+
+# Four quadrants: boundaries at x=0 (cos=0) and y=0.10 (chance)
+ax.fill_between([0, x_lim[1]], 0.10, y_lim[1], color="#c8e6c9", alpha=0.5, zorder=0)
+ax.fill_between([x_lim[0], 0], y_lim[0], 0.10, color="#ffcdd2", alpha=0.5, zorder=0)
+ax.fill_between([x_lim[0], 0], 0.10, y_lim[1], color="#f5f5f5", alpha=0.6, zorder=0)
+ax.fill_between([0, x_lim[1]], y_lim[0], 0.10, color="#f5f5f5", alpha=0.6, zorder=0)
+
+ax.axvline(0, color="gray", lw=0.6, ls="--", zorder=1)
+ax.axhline(0.10, color="gray", lw=0.6, ls="--", zorder=1)
+
+# Frozen baseline horizontal line (acc = FROZEN)
+ax.axhline(FROZEN, color="#555", lw=1.2, ls=":", zorder=2)
+ax.text(1.05, FROZEN + 0.01, f"frozen baseline ({FROZEN:.3f})", fontsize=7,
+ color="#555", ha="right", va="bottom")
+
+# Quadrant labels
+ax.text(-0.11, 0.45, "cos < 0,\nacc > chance", fontsize=6.5, color="#888",
+ ha="center", va="center", style="italic", rotation=90)
+ax.text(0.55, 0.04, "cos > 0,\nacc < chance", fontsize=6.5, color="#888",
+ ha="center", va="center", style="italic")
+ax.text(0.55, 0.45, '"looks like\n learning"', fontsize=7, color="#388e3c",
+ ha="center", va="center", fontweight="bold")
+ax.text(-0.11, 0.04, "neither", fontsize=6.5, color="#c62828",
+ ha="center", va="center", style="italic")
+
+# Plot all 3 seeds per method: x=gamma, y=acc
+for method in ["BP", "FA", "DFA"]:
+ seeds = per_seed[method]
+ gammas = [s["gamma"] for s in seeds]
+ accs = [s["acc"] for s in seeds]
+ ax.scatter(gammas, accs, c=COLORS[method], marker=MARKERS[method],
+ s=60, zorder=5, edgecolors="k", linewidths=0.4, label=method)
+
+# Labels — annotate near the centroid of each cluster
+for method, offsets in [("BP", (-8, 8)), ("FA", (8, -10)), ("DFA", (8, -10))]:
+ seeds = per_seed[method]
+ cx = np.mean([s["gamma"] for s in seeds])
+ cy = np.mean([s["acc"] for s in seeds])
+ ax.annotate(method, (cx, cy),
+ xytext=offsets, textcoords="offset points", fontsize=9, fontweight="bold",
+ color=COLORS[method])
+
+ax.set_xlabel("Aggregate $\\Gamma$ (cosine)")
+ax.set_ylabel("Test accuracy")
+ax.set_xlim(*x_lim)
+ax.set_ylim(*y_lim)
+# No title — user will add caption externally
+
+ax.grid(True, which="major", color="#d0d0d0", linewidth=0.4, linestyle=":")
+ax.set_axisbelow(True)
+
+out = os.path.join(REPO_ROOT, "paper/figures/fig_d512L2_panelA.pdf")
+fig.savefig(out, bbox_inches="tight", dpi=300)
+fig.savefig(out.replace(".pdf", ".png"), bbox_inches="tight", dpi=200)
+print(f"Saved: {out}")