1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
|
"""Panel A style scatter for d=512 L=2 qualifying seeds, with frozen baseline line."""
import os
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
# Root of the repository; the figure is written below <REPO_ROOT>/paper/figures.
# Set FA_REPO_ROOT to run from another checkout; the default keeps the
# original hard-coded location so existing invocations behave the same.
REPO_ROOT = os.environ.get("FA_REPO_ROOT", "/home/yurenh2/fa")
# Per-seed data for qualifying seeds 1, 2, 5
# Each record holds the final test accuracy ("acc", plotted on y) and the
# aggregate cosine Γ ("gamma", plotted on x) for one seed.
per_seed = {
    "BP": [{"acc": 0.6061, "gamma": 1.0},
           {"acc": 0.6076, "gamma": 1.0},
           {"acc": 0.6065, "gamma": 1.0}],
    "FA": [{"acc": 0.3471, "gamma": 0.4840},
           {"acc": 0.3464, "gamma": 0.4721},
           {"acc": 0.3410, "gamma": 0.4924}],
    "DFA": [{"acc": 0.2978, "gamma": 0.2062},
            {"acc": 0.2968, "gamma": 0.1786},
            {"acc": 0.2963, "gamma": 0.1940}],
}
# Test accuracy of the frozen baseline, drawn as a horizontal reference line.
FROZEN = 0.349
# Per-method styling; keys mirror those of per_seed.
COLORS = {"BP": "#2166ac", "FA": "#e08214", "DFA": "#b2182b"}
MARKERS = {"BP": "o", "FA": "s", "DFA": "D"}
# Compact serif typography sized for a small paper panel.
plt.rcParams.update(
    {
        "font.size": 9,
        "axes.labelsize": 10,
        "axes.titlesize": 10,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "font.family": "serif",
    }
)
fig, ax = plt.subplots(figsize=(4.0, 3.5))

# Axes swapped: x = Γ (cosine), y = accuracy.
x_lim = (-0.22, 1.12)
y_lim = (-0.02, 0.72)
x_lo, x_hi = x_lim
y_lo, y_hi = y_lim
chance = 0.10  # chance-level accuracy; the horizontal quadrant boundary

# Quadrant shading split at cos = 0 and acc = chance.
# Rows are (x-span, y-bottom, y-top, colour, alpha), drawn in list order.
quadrant_fills = [
    ([0, x_hi], chance, y_hi, "#c8e6c9", 0.5),  # cos > 0, acc > chance
    ([x_lo, 0], y_lo, chance, "#ffcdd2", 0.5),  # cos < 0, acc < chance
    ([x_lo, 0], chance, y_hi, "#f5f5f5", 0.6),  # cos < 0, acc > chance
    ([0, x_hi], y_lo, chance, "#f5f5f5", 0.6),  # cos > 0, acc < chance
]
for span, bottom, top, shade, opacity in quadrant_fills:
    ax.fill_between(span, bottom, top, color=shade, alpha=opacity, zorder=0)

# Dashed quadrant boundaries.
ax.axvline(0, color="gray", lw=0.6, ls="--", zorder=1)
ax.axhline(chance, color="gray", lw=0.6, ls="--", zorder=1)

# Dotted reference line at the frozen-baseline accuracy, labelled just above it.
ax.axhline(FROZEN, color="#555", lw=1.2, ls=":", zorder=2)
frozen_label = f"frozen baseline ({FROZEN:.3f})"
ax.text(1.05, FROZEN + 0.01, frozen_label, fontsize=7,
        color="#555", ha="right", va="bottom")

# Quadrant captions: (x, y, text, per-caption style), added in list order.
quadrant_captions = [
    (-0.11, 0.45, "cos < 0,\nacc > chance",
     {"fontsize": 6.5, "color": "#888", "style": "italic", "rotation": 90}),
    (0.55, 0.04, "cos > 0,\nacc < chance",
     {"fontsize": 6.5, "color": "#888", "style": "italic"}),
    (0.55, 0.45, '"looks like\n learning"',
     {"fontsize": 7, "color": "#388e3c", "fontweight": "bold"}),
    (-0.11, 0.04, "neither",
     {"fontsize": 6.5, "color": "#c62828", "style": "italic"}),
]
for tx, ty, caption, look in quadrant_captions:
    ax.text(tx, ty, caption, ha="center", va="center", **look)
# One scatter series per method; each point is a seed at (gamma, acc).
for name in ("BP", "FA", "DFA"):
    runs = per_seed[name]
    xs = [run["gamma"] for run in runs]
    ys = [run["acc"] for run in runs]
    ax.scatter(
        xs,
        ys,
        c=COLORS[name],
        marker=MARKERS[name],
        s=60,
        zorder=5,
        edgecolors="k",
        linewidths=0.4,
        label=name,
    )

# Method name placed at a fixed point offset from each cluster's centroid.
label_offsets = {"BP": (-8, 8), "FA": (8, -10), "DFA": (8, -10)}
for name, shift in label_offsets.items():
    runs = per_seed[name]
    centre_x = np.mean([run["gamma"] for run in runs])
    centre_y = np.mean([run["acc"] for run in runs])
    ax.annotate(
        name,
        (centre_x, centre_y),
        xytext=shift,
        textcoords="offset points",
        fontsize=9,
        fontweight="bold",
        color=COLORS[name],
    )
ax.set_xlabel("Aggregate $\\Gamma$ (cosine)")
ax.set_ylabel("Test accuracy")
ax.set_xlim(*x_lim)
ax.set_ylim(*y_lim)
# No title — user will add caption externally
ax.grid(True, which="major", color="#d0d0d0", linewidth=0.4, linestyle=":")
ax.set_axisbelow(True)

out = os.path.join(REPO_ROOT, "paper/figures/fig_d512L2_panelA.pdf")
# savefig does not create missing directories, so ensure the target exists.
os.makedirs(os.path.dirname(out), exist_ok=True)
fig.savefig(out, bbox_inches="tight", dpi=300)
# Swap only the extension: str.replace(".pdf", ".png") would also rewrite any
# ".pdf" that happens to appear earlier in the path (e.g. inside REPO_ROOT).
png_out = os.path.splitext(out)[0] + ".png"
fig.savefig(png_out, bbox_inches="tight", dpi=200)
# Release the figure once both files are written.
plt.close(fig)
print(f"Saved: {out}")
|