"""Panel A style scatter for d=512 L=2 qualifying seeds, with frozen baseline line.""" import os import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import matplotlib.patches as mpatches import numpy as np REPO_ROOT = "/home/yurenh2/fa" # Per-seed data for qualifying seeds 1, 2, 5 per_seed = { "BP": [{"acc": 0.6061, "gamma": 1.0}, {"acc": 0.6076, "gamma": 1.0}, {"acc": 0.6065, "gamma": 1.0}], "FA": [{"acc": 0.3471, "gamma": 0.4840}, {"acc": 0.3464, "gamma": 0.4721}, {"acc": 0.3410, "gamma": 0.4924}], "DFA": [{"acc": 0.2978, "gamma": 0.2062}, {"acc": 0.2968, "gamma": 0.1786}, {"acc": 0.2963, "gamma": 0.1940}], } FROZEN = 0.349 COLORS = {"BP": "#2166ac", "FA": "#e08214", "DFA": "#b2182b"} MARKERS = {"BP": "o", "FA": "s", "DFA": "D"} plt.rcParams.update({ "font.size": 9, "axes.labelsize": 10, "axes.titlesize": 10, "xtick.labelsize": 8, "ytick.labelsize": 8, "font.family": "serif", }) fig, ax = plt.subplots(figsize=(4.0, 3.5)) # Axes swapped: x = Γ (cosine), y = accuracy x_lim = (-0.22, 1.12) y_lim = (-0.02, 0.72) # Four quadrants: boundaries at x=0 (cos=0) and y=0.10 (chance) ax.fill_between([0, x_lim[1]], 0.10, y_lim[1], color="#c8e6c9", alpha=0.5, zorder=0) ax.fill_between([x_lim[0], 0], y_lim[0], 0.10, color="#ffcdd2", alpha=0.5, zorder=0) ax.fill_between([x_lim[0], 0], 0.10, y_lim[1], color="#f5f5f5", alpha=0.6, zorder=0) ax.fill_between([0, x_lim[1]], y_lim[0], 0.10, color="#f5f5f5", alpha=0.6, zorder=0) ax.axvline(0, color="gray", lw=0.6, ls="--", zorder=1) ax.axhline(0.10, color="gray", lw=0.6, ls="--", zorder=1) # Frozen baseline horizontal line (acc = FROZEN) ax.axhline(FROZEN, color="#555", lw=1.2, ls=":", zorder=2) ax.text(1.05, FROZEN + 0.01, f"frozen baseline ({FROZEN:.3f})", fontsize=7, color="#555", ha="right", va="bottom") # Quadrant labels ax.text(-0.11, 0.45, "cos < 0,\nacc > chance", fontsize=6.5, color="#888", ha="center", va="center", style="italic", rotation=90) ax.text(0.55, 0.04, "cos > 0,\nacc < chance", fontsize=6.5, color="#888", ha="center", va="center", style="italic") ax.text(0.55, 0.45, '"looks like\n learning"', fontsize=7, color="#388e3c", ha="center", va="center", fontweight="bold") ax.text(-0.11, 0.04, "neither", fontsize=6.5, color="#c62828", ha="center", va="center", style="italic") # Plot all 3 seeds per method: x=gamma, y=acc for method in ["BP", "FA", "DFA"]: seeds = per_seed[method] gammas = [s["gamma"] for s in seeds] accs = [s["acc"] for s in seeds] ax.scatter(gammas, accs, c=COLORS[method], marker=MARKERS[method], s=60, zorder=5, edgecolors="k", linewidths=0.4, label=method) # Labels — annotate near the centroid of each cluster for method, offsets in [("BP", (-8, 8)), ("FA", (8, -10)), ("DFA", (8, -10))]: seeds = per_seed[method] cx = np.mean([s["gamma"] for s in seeds]) cy = np.mean([s["acc"] for s in seeds]) ax.annotate(method, (cx, cy), xytext=offsets, textcoords="offset points", fontsize=9, fontweight="bold", color=COLORS[method]) ax.set_xlabel("Aggregate $\\Gamma$ (cosine)") ax.set_ylabel("Test accuracy") ax.set_xlim(*x_lim) ax.set_ylim(*y_lim) # No title — user will add caption externally ax.grid(True, which="major", color="#d0d0d0", linewidth=0.4, linestyle=":") ax.set_axisbelow(True) out = os.path.join(REPO_ROOT, "paper/figures/fig_d512L2_panelA.pdf") fig.savefig(out, bbox_inches="tight", dpi=300) fig.savefig(out.replace(".pdf", ".png"), bbox_inches="tight", dpi=200) print(f"Saved: {out}")