summaryrefslogtreecommitdiff
path: root/paper/figures/render_fig_cos_acc_dissociation.py
blob: fff7f65123be82bc4ddcdcbe258bdbdafd4b3e96 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""Render Figure: cos-vs-accuracy cross-method dissociation.

Shows the v2.33 finding: under matched penalty rescue (lam=1e-2, 30ep, 3 seeds)
on the audited 4-block d=256 ResMLP, three independent functional metrics
(headline accuracy, single-step nudging, integrated training-loss decrease)
all rank SB ≫ CB ≈ DFA, while deep cosine ranks CB > SB > DFA — the only
ordering that disagrees with the functional ranking.

Sources (all 3-seed):
  results/round38_sb_penalty_30ep_s{42,123,456}/results_cifar10.json
  results/round38_cb_penalty_30ep_s{42,123,456}/results_cifar10.json
  results/round41_dfa_penalty_30ep{,_s{123,456}}/results_cifar10.json
  results/nudging_test_3seed_summary.json
  results/training_loss_decrease_3seed.json
"""
import os
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Repository root that the output path is built from.  The default is the
# absolute path hard-coded in the original script; set the FA_REPO_ROOT
# environment variable to render into a different checkout without editing
# this file.  An empty override falls back to the default.
REPO_ROOT = os.environ.get("FA_REPO_ROOT") or "/home/yurenh2/fa"

# Hard-coded three-seed summary numbers, copied from the saved result JSONs
# (cross-checked against the §4 ¶4 prose).  Every per-method list below follows
# the order of `methods`: index 0 = SB+pen, 1 = CB+pen, 2 = DFA+pen.
methods = ["SB+pen", "CB+pen", "DFA+pen"]
colors = {
    "SB+pen": "#1f77b4",
    "CB+pen": "#d62728",
    "DFA+pen": "#7f7f7f",
}

# Raw value of each metric, one entry per method.
# §4 ¶4 names accuracy, nudging and the training-loss trajectory as the three
# functional metrics.  Deep ρ is deliberately left out: it ranks
# CB > SB > DFA (the same order as cos) rather than SB > CB > DFA, so it belongs
# with cos as a "directional alignment" metric, whereas the functional triad
# below is about forward-state usefulness.
metrics = {
    "deep cos": [0.322, 0.679, 0.151],
    "accuracy": [0.453, 0.360, 0.360],
    "|nudging|": [1.929e-3, 4.264e-4, 4.978e-5],
    "loss decrease": [0.447, 0.121, 0.095],
}
# Matching standard deviations (same keys, same method order).
# NOTE(review): not read anywhere else in the visible script.
metric_stds = {
    "deep cos": [0.007, 0.008, 0.025],
    "accuracy": [0.003, 0.003, 0.001],
    "|nudging|": [0.113e-3, 0.024e-3, 0.0044e-3],
    "loss decrease": [0.008, 0.003, 0.007],
}


def _scale_to_peak(values):
    """Divide every value by the largest one, so the best method scores 1."""
    peak = max(values)
    return [value / peak for value in values]


# Column order of the plot follows the insertion order of `metrics`.
metric_names = list(metrics)

# Per-metric rescaling so that lines for different metrics share one y-axis.
# The result lies in [0, 1] as long as the raw values are non-negative.
norm = {name: _scale_to_peak(metrics[name]) for name in metric_names}

def _label_dy(column, min_gap=0.06, step_pts=9.0):
    """Return per-method vertical label offsets (in points) for one metric.

    ``column`` holds the normalized scores of all methods for a single metric
    column.  Every value label is drawn at a fixed horizontal offset from its
    marker, so two methods whose scores tie (or nearly tie) would otherwise
    print their labels on top of each other.  Scores are ranked from high to
    low; each run of neighbours closer than ``min_gap`` is treated as one
    cluster and its labels are spread ``step_pts`` apart, centred on zero.
    Isolated scores keep a zero offset.

    ``min_gap`` is in normalized-score units and ``step_pts`` in points; both
    are heuristics chosen for 7 pt labels and may need tuning if the figure
    size, font size or y-limits change.
    """
    # Highest score first; sorted() is stable, so exact ties keep method order.
    ranked = sorted(range(len(column)), key=lambda k: -column[k])
    dy = [0.0] * len(column)
    start = 0
    for pos in range(1, len(ranked) + 1):
        crowded = (
            pos < len(ranked)
            and column[ranked[pos - 1]] - column[ranked[pos]] < min_gap
        )
        if crowded:
            continue  # still inside the current cluster
        cluster = ranked[start:pos]
        mid = (len(cluster) - 1) / 2
        for rank, k in enumerate(cluster):
            dy[k] = (mid - rank) * step_pts
        start = pos
    return dy


fig, ax = plt.subplots(figsize=(6.0, 3.4))

# One x position per metric column.
x = np.arange(len(metric_names))

# label_dy[m][i]: vertical nudge for the label of method i in metric column m.
label_dy = {m: _label_dy(norm[m]) for m in metric_names}

for i, method in enumerate(methods):
    # This method's normalized score in every metric column.
    y = [norm[m][i] for m in metric_names]
    ax.plot(x, y, "o-", color=colors[method], lw=2.2, markersize=9, label=method)
    # Annotate each point with the raw (un-normalized) value.
    for xi, yi, m in zip(x, y, metric_names):
        raw = metrics[m][i]
        if "nudg" in m:
            # Nudging values are tiny; print them in units of 1e-3.
            label = f"{raw*1e3:.2f}e-3"
        elif "cos" in m:
            # Cosine is signed; always show the sign.
            label = f"{raw:+.3f}"
        else:
            label = f"{raw:.3f}"
        # Label sits 8 pt to the right of the marker; the vertical component
        # separates labels of methods that tie or nearly tie in this column.
        ax.annotate(label, (xi, yi), textcoords="offset points",
                    xytext=(8, label_dy[m][i]), fontsize=7,
                    color=colors[method], ha="left", va="center")

# Axis cosmetics: one labelled tick per metric column and a fixed y-range.
ax.set_xticks(x)
ax.set_xticklabels(metric_names, fontsize=9)
ax.set_ylabel("normalized score (max = 1 across the 3 methods)", fontsize=9)
ax.set_ylim(-0.05, 1.18)

# Two-line title: experimental setting, then the headline observation.
title_lines = (
    "Cross-method functional dissociation (3 seeds, 30 ep, $\\lambda{=}10^{-2}$)",
    "all 3 functional metrics rank SB $\\gg$ CB $\\approx$ DFA; deep cos is the only one that disagrees",
)
ax.set_title("\n".join(title_lines), fontsize=9)
ax.legend(loc="upper right", fontsize=8, framealpha=0.95)
ax.grid(True, axis="y", alpha=0.3)

# Visual guide: tint the band around x = 0 (the cos column, the one whose
# ranking disagrees) and caption both groups of columns near the top.
ax.axvspan(-0.4, 0.4, color="#fff3e0", alpha=0.5, zorder=0)
banner_style = {"ha": "center", "fontsize": 7, "style": "italic"}
for x_pos, caption, tint in (
    (0, "cos: CB top", "#cc4400"),
    (2.5, "accuracy / nudging / training-loss decrease: SB top", "#1f5f9f"),
):
    ax.text(x_pos, 1.13, caption, color=tint, **banner_style)

# Tighten the layout, write the figure as a PDF under REPO_ROOT, and report
# where it went.
plt.tight_layout()
out = os.path.join(REPO_ROOT, "paper/figures/fig_cos_acc_dissociation.pdf")
save_options = {"bbox_inches": "tight", "dpi": 200}
plt.savefig(out, **save_options)
print("Saved {}".format(out))