summaryrefslogtreecommitdiff
path: root/paper/figures/render_fig_d512L2_panelA.py
blob: b8fabf8cde3cbb9e243df246e5fb5f6f00c6ad72 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""Panel A style scatter for d=512 L=2 qualifying seeds, with frozen baseline line."""
import os
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np

# Repository root. This script lives at <root>/paper/figures/, so the root is
# three directory levels above this file. Deriving it here (rather than
# hard-coding a per-user absolute path) lets the figure render on any checkout.
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Per-seed results for the qualifying seeds 1, 2, 5 (one dict per seed, in that
# order). "acc" is the test accuracy (plotted on y); "gamma" is the aggregate
# Gamma cosine (plotted on x).
per_seed = {
    "BP":  [{"acc": 0.6061, "gamma": 1.0},
            {"acc": 0.6076, "gamma": 1.0},
            {"acc": 0.6065, "gamma": 1.0}],
    "FA":  [{"acc": 0.3471, "gamma": 0.4840},
            {"acc": 0.3464, "gamma": 0.4721},
            {"acc": 0.3410, "gamma": 0.4924}],
    "DFA": [{"acc": 0.2978, "gamma": 0.2062},
            {"acc": 0.2968, "gamma": 0.1786},
            {"acc": 0.2963, "gamma": 0.1940}],
}
# Test accuracy of the frozen baseline; drawn as a dotted horizontal reference.
FROZEN = 0.349

# Per-method styling; keys must match those of `per_seed`.
COLORS = {"BP": "#2166ac", "FA": "#e08214", "DFA": "#b2182b"}
MARKERS = {"BP": "o", "FA": "s", "DFA": "D"}

# Global typography for this figure: serif font, 9 pt base text,
# 10 pt axis labels and titles, 8 pt tick labels.
plt.rcParams.update(
    {
        "font.family": "serif",
        "font.size": 9,
        "axes.labelsize": 10,
        "axes.titlesize": 10,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
    }
)

fig, ax = plt.subplots(figsize=(4.0, 3.5))

# Axes swapped: x = Gamma (cosine), y = accuracy
x_lim = (-0.22, 1.12)
y_lim = (-0.02, 0.72)

# Chance-level test accuracy. The horizontal quadrant boundary, its guide line
# and the bottom-row label placement all derive from this single value.
# NOTE(review): presumably 1/num_classes for a 10-class task — confirm.
CHANCE = 0.10

# Four quadrants split at x=0 (cos = 0) and y=CHANCE. Green marks cos > 0 with
# above-chance accuracy. Red marks cos < 0 with below-chance accuracy. The two
# mixed quadrants are light grey. zorder=0 keeps the fills behind every other
# artist.
ax.fill_between([0, x_lim[1]], CHANCE, y_lim[1], color="#c8e6c9", alpha=0.5, zorder=0)
ax.fill_between([x_lim[0], 0], y_lim[0], CHANCE, color="#ffcdd2", alpha=0.5, zorder=0)
ax.fill_between([x_lim[0], 0], CHANCE, y_lim[1], color="#f5f5f5", alpha=0.6, zorder=0)
ax.fill_between([0, x_lim[1]], y_lim[0], CHANCE, color="#f5f5f5", alpha=0.6, zorder=0)

ax.axvline(0, color="gray", lw=0.6, ls="--", zorder=1)
ax.axhline(CHANCE, color="gray", lw=0.6, ls="--", zorder=1)

# Frozen baseline horizontal line (acc = FROZEN). Its label is right-aligned
# slightly inside the right x-limit so it follows x_lim if the limits change.
ax.axhline(FROZEN, color="#555", lw=1.2, ls=":", zorder=2)
ax.text(x_lim[1] - 0.07, FROZEN + 0.01, f"frozen baseline ({FROZEN:.3f})", fontsize=7,
        color="#555", ha="right", va="bottom")

# Quadrant labels. The left column is centred on [x_lim[0], 0] and the bottom
# row on [y_lim[0], CHANCE], so those labels stay inside their bands if the
# limits or the chance level are changed.
x_left_mid = x_lim[0] / 2
y_low_mid = (y_lim[0] + CHANCE) / 2
ax.text(x_left_mid, 0.45, "cos < 0,\nacc > chance", fontsize=6.5, color="#888",
        ha="center", va="center", style="italic", rotation=90)
ax.text(0.55, y_low_mid, "cos > 0,\nacc < chance", fontsize=6.5, color="#888",
        ha="center", va="center", style="italic")
ax.text(0.55, 0.45, '"looks like\n  learning"', fontsize=7, color="#388e3c",
        ha="center", va="center", fontweight="bold")
ax.text(x_left_mid, y_low_mid, "neither", fontsize=6.5, color="#c62828",
        ha="center", va="center", style="italic")

# One marker per seed for each method, at (gamma, acc).
for name in ("BP", "FA", "DFA"):
    runs = per_seed[name]
    ax.scatter(
        [run["gamma"] for run in runs],
        [run["acc"] for run in runs],
        c=COLORS[name],
        marker=MARKERS[name],
        s=60,
        zorder=5,
        edgecolors="k",
        linewidths=0.4,
        label=name,
    )

# Name each method's cluster. The text is anchored at the cluster's mean
# (gamma, acc) and shifted by a fixed (dx, dy) offset given in points.
label_offsets = {"BP": (-8, 8), "FA": (8, -10), "DFA": (8, -10)}
for name, (dx, dy) in label_offsets.items():
    runs = per_seed[name]
    centroid = (
        np.mean([run["gamma"] for run in runs]),
        np.mean([run["acc"] for run in runs]),
    )
    ax.annotate(
        name,
        centroid,
        xytext=(dx, dy),
        textcoords="offset points",
        fontsize=9,
        fontweight="bold",
        color=COLORS[name],
    )

ax.set_xlabel("Aggregate $\\Gamma$ (cosine)")
ax.set_ylabel("Test accuracy")
ax.set_xlim(*x_lim)
ax.set_ylim(*y_lim)
# No title — user will add caption externally

ax.grid(True, which="major", color="#d0d0d0", linewidth=0.4, linestyle=":")
ax.set_axisbelow(True)

out = os.path.join(REPO_ROOT, "paper", "figures", "fig_d512L2_panelA.pdf")
# savefig does not create missing directories; make sure the target exists.
os.makedirs(os.path.dirname(out), exist_ok=True)
fig.savefig(out, bbox_inches="tight", dpi=300)
# Swap only the file extension. A plain str.replace(".pdf", ".png") would also
# rewrite any ".pdf" that appears earlier in the directory path.
fig.savefig(os.path.splitext(out)[0] + ".png", bbox_inches="tight", dpi=200)
plt.close(fig)
print(f"Saved: {out}")