summaryrefslogtreecommitdiff
path: root/research/flossing/analyze_tangent.py
blob: 32fde5ebc95a70413234d1139eb49091dd7b1ba6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""Analyze the saved tangent basis Q (final, after all ACT steps).

Q shape: (N, seq_full=82, hidden=512, k=4).
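Splits samples by success/failure (the saved `exact_correct` flag) and computes:
 - position activity per mode: ||Q[n, s, :, i]||_2 (norm over hidden), mean/std over samples
 - hidden-dim activity per mode: ||Q[n, :, h, i]||_2 (norm over positions), mean over samples
 - subspace overlap ||Q_a^T Q_b||_F^2 / k (mean squared cosine) for succ-succ,
   fail-fail and succ-fail sample pairs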
"""
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import os

# Load the saved tangent basis and per-sample correctness flags. The .npz archive
# is opened as a context manager so its file handle is released as soon as the
# two arrays have been read into memory.
with np.load("/home/yurenh2/rrm/research/flossing/tangent_modes_512.npz") as d:
    Q = d["Q_final"]                  # (N, seq, hidden, k)
    exact = d["exact_correct"].astype(bool)  # (N,)
N, seq, hidden, K = Q.shape
# The boolean mask is used to index Q's first axis everywhere below, so the
# two arrays must describe the same samples.
if exact.shape != (N,):
    raise ValueError(f"exact_correct has shape {exact.shape}, expected ({N},)")
print(f"N={N}, seq={seq}, hidden={hidden}, K={K}, acc={exact.mean():.3f}")

# All figures produced by this script are written to this directory.
OUT = "/home/yurenh2/rrm/research/flossing/plots_tangent"
os.makedirs(OUT, exist_ok=True)

# Position activity per mode: ||Q[:, s, :, i]||_2 averaged across samples in group
def pos_activity(Q_, exact_mask):
    """Per-position activity of every mode for one group of samples.

    Q_ is indexed as (sample, position, hidden, mode). For each sample selected
    by ``exact_mask``, the L2 norm over the hidden axis (axis 2) is taken. The
    result is the mean and the std of those norms across the selected samples,
    each of shape (seq, k).
    """
    per_sample_norms = np.linalg.norm(Q_[exact_mask], axis=2)  # (n, seq, k)
    mean_over_samples = per_sample_norms.mean(axis=0)
    std_over_samples = per_sample_norms.std(axis=0)
    return mean_over_samples, std_over_samples

# Position activity (mean, std over samples) for the success group, then the failure group.
(pa_s_mean, pa_s_std), (pa_f_mean, pa_f_std) = (
    pos_activity(Q, group_mask) for group_mask in (exact, ~exact)
)

# Hidden activity per mode
def hid_activity(Q_, exact_mask):
    """Per-hidden-dim activity of every mode for one group of samples.

    Q_ is indexed as (sample, position, hidden, mode). For each sample selected
    by ``exact_mask``, the L2 norm over the position axis (axis 1) is taken. The
    result is the mean and the std of those norms across the selected samples,
    each of shape (hidden, k).
    """
    selected = Q_[exact_mask]                          # (n, seq, hidden, k)
    norms = np.linalg.norm(selected, axis=1)           # (n, hidden, k)
    return norms.mean(axis=0), norms.std(axis=0)

# Mean hidden-dim activity per group; the std is not plotted, so only the mean is kept.
ha_s_mean, ha_f_mean = (
    hid_activity(Q, group_mask)[0] for group_mask in (exact, ~exact)
)

# ---- Plot 1: position activity per mode ----
# One panel per mode, two panels per row. The row count is derived from K so no
# mode is silently dropped (a fixed 2x2 grid only holds 4); spare panels are hidden.
n_cols = 2
n_rows = max(1, (K + n_cols - 1) // n_cols)  # ceil(K / n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(13, 3.5 * n_rows),
                         sharex=True, sharey=True, squeeze=False)
x = np.arange(seq)  # shared by every panel, so built once
groups = (
    (pa_s_mean, pa_s_std, "C0", "succ"),
    (pa_f_mean, pa_f_std, "C3", "fail"),
)
for i, ax in enumerate(axes.flat):
    if i >= K:
        ax.set_visible(False)
        continue
    for g_mean, g_std, color, label in groups:
        # Mean curve with a +/- 1 std band across the group's samples.
        ax.plot(x, g_mean[:, i], f"{color}-", label=label, lw=1.5)
        ax.fill_between(x, g_mean[:, i] - g_std[:, i], g_mean[:, i] + g_std[:, i],
                        color=color, alpha=0.2)
    ax.axvline(0.5, color="k", ls=":", lw=0.6)  # mark puzzle_emb boundary
    ax.set_title(f"mode {i+1}: position activity")
    ax.set_xlabel("seq position (0=puzzle_emb, 1-81=Sudoku cells)")
    ax.set_ylabel(r"$\|Q[\cdot,:,i]\|_2$")
    ax.legend()
    ax.grid(alpha=0.3)
fig.suptitle(f"Per-mode position activity (N={N}, acc={exact.mean():.2%})")
fig.tight_layout()
fig.savefig(f"{OUT}/position_activity.png", dpi=130)
plt.close(fig)

# ---- Plot 2: position activity as 9x9 grid (positions 1-81 = Sudoku cells) ----
# Top row: successful samples; bottom row: failed samples; one column per mode.
# Both rows of a column share vmin/vmax, so one colorbar per column is valid.
# squeeze=False keeps `axes` 2-D, so axes[row, i] also works when K == 1.
cells = slice(1, 82)  # drop position 0 (puzzle_emb); assumes cells 1-81 are in row-major board order
fig, axes = plt.subplots(2, K, figsize=(K*3, 6), squeeze=False)
for i in range(K):
    s_grid = pa_s_mean[cells, i].reshape(9, 9)
    f_grid = pa_f_mean[cells, i].reshape(9, 9)
    vmax = max(s_grid.max(), f_grid.max())
    axes[0, i].imshow(s_grid, vmin=0, vmax=vmax, cmap="viridis")
    axes[0, i].set_title(f"succ mode {i+1}")
    axes[0, i].set_xticks([])
    axes[0, i].set_yticks([])
    im = axes[1, i].imshow(f_grid, vmin=0, vmax=vmax, cmap="viridis")
    axes[1, i].set_title(f"fail mode {i+1}")
    axes[1, i].set_xticks([])
    axes[1, i].set_yticks([])
    plt.colorbar(im, ax=axes[:, i], fraction=0.046)
fig.suptitle("Tangent mode activity over the 9x9 Sudoku board")
fig.savefig(f"{OUT}/sudoku_grid_activity.png", dpi=130, bbox_inches="tight")
plt.close(fig)

# ---- Plot 3: hidden activity per mode ----
# Same panel layout as the position plot: two per row, row count derived from K
# so every mode gets a panel; spare panels are hidden.
n_cols = 2
n_rows = max(1, (K + n_cols - 1) // n_cols)  # ceil(K / n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(13, 3.5 * n_rows),
                         sharex=True, squeeze=False)
for i, ax in enumerate(axes.flat):
    if i >= K:
        ax.set_visible(False)
        continue
    # Sort hidden dims by combined (succ + fail) activity, descending, and apply
    # the same ordering to both curves so they stay comparable dim-by-dim.
    order = np.argsort(-(ha_s_mean[:, i] + ha_f_mean[:, i]))
    ax.plot(ha_s_mean[order, i], "C0-", label="succ (sorted)", lw=1)
    ax.plot(ha_f_mean[order, i], "C3-", label="fail (sorted)", lw=1)
    ax.set_title(f"mode {i+1}: hidden-dim activity (sorted by sum)")
    ax.set_xlabel("hidden dim (sorted)")
    ax.set_ylabel(r"$\|Q[:,h,i]\|_2$")
    ax.legend()
    ax.grid(alpha=0.3)
fig.tight_layout()
fig.savefig(f"{OUT}/hidden_activity.png", dpi=130)
plt.close(fig)

# ---- Subspace cosine analysis ----
# Flatten Q to (N, seq*hidden, k): each sample's k columns span a k-d subspace of
# the flattened state space. The dense (state_dim x state_dim) projectors
# Q_n Q_n^T are too large to average, so samples are compared pairwise instead:
#   overlap(a, b) = ||Q_a^T Q_b||_F^2 / k
# which, for orthonormal bases, is the mean squared cosine of the principal angles
# (1 = identical subspaces, 0 = orthogonal). It is summarised within and across groups.
Qflat = Q.reshape(N, seq*hidden, K)
def projector_sample(qf):
    """Return one sample's flattened basis unchanged (identity placeholder).

    ``qf`` is expected to be a single sample's basis of shape (state_dim, k),
    presumably with orthonormal columns (the saved basis is described as QR'd;
    this is not checked here). No projector or eigen-decomposition is computed:
    the very same object is returned. This helper has no caller in this file;
    the subspace comparison is performed pairwise by ``avg_pos_outer`` and
    ``cross_overlap``.
    """
    return qf

def avg_pos_outer(qflat_subset, max_pairs=5000, seed=0):
    """Mean/std subspace overlap between pairs of samples from the same group.

    Each sample is a (state_dim, k) basis. For a pair (a, b) the overlap is
    ||q_a^T q_b||_F^2 / k, which lies in [0, 1] when both bases have
    orthonormal columns (1 = same subspace, 0 = orthogonal subspaces).
    Averaging the dense q q^T projectors is infeasible (state_dim x state_dim),
    which is why the comparison is done pair by pair.

    Parameters
    ----------
    qflat_subset : array of shape (n, state_dim, k)
    max_pairs : int
        If the group has at most this many distinct pairs, every pair is used
        exactly once (exact statistic). Otherwise ``max_pairs`` random pairs of
        distinct samples are drawn.
    seed : int
        Seed for the random pair sampling (only used in the sampled regime).

    Returns
    -------
    (mean, std) of the pairwise overlaps as floats, or (nan, nan) when n < 2.
    """
    n = qflat_subset.shape[0]
    if n < 2:
        return np.nan, np.nan
    n_modes = qflat_subset.shape[-1]
    if n * (n - 1) // 2 <= max_pairs:
        # Few enough pairs: enumerate each unordered pair once instead of drawing
        # random pairs with replacement (which repeats some and misses others).
        first, second = np.triu_indices(n, k=1)
        pair_iter = zip(first, second)
    else:
        rng = np.random.default_rng(seed)
        pair_iter = (rng.choice(n, size=2, replace=False) for _ in range(max_pairs))
    overlaps = []
    for a, b in pair_iter:
        M = qflat_subset[a].T @ qflat_subset[b]  # (k, k)
        overlaps.append((M ** 2).sum() / n_modes)
    return float(np.mean(overlaps)), float(np.std(overlaps))

print("\nSubspace overlap (squared cosine, averaged over random pairs):")
# Within-group overlap statistics: success vs success first, then failure vs failure.
(ovl_ss, std_ss), (ovl_ff, std_ff) = (
    avg_pos_outer(Qflat[group_mask]) for group_mask in (exact, ~exact)
)
# Cross-group overlap
def cross_overlap(A, B, max_pairs=5000, seed=0):
    """Mean/std subspace overlap between a sample of group A and a sample of group B.

    A and B have shapes (nA, state_dim, k) and (nB, state_dim, k). For a pair
    (a, b) the overlap is ||A[a]^T B[b]||_F^2 / k (see ``avg_pos_outer``).

    Parameters
    ----------
    A, B : arrays of shape (n, state_dim, k)
    max_pairs : int
        If nA * nB <= max_pairs, every (a, b) pair is used exactly once.
        Otherwise ``max_pairs`` random pairs are drawn.
    seed : int
        Seed for the random pair sampling (only used in the sampled regime).

    Returns
    -------
    (mean, std) of the overlaps as floats, or (nan, nan) if either group is empty.
    """
    nA, nB = A.shape[0], B.shape[0]
    if nA == 0 or nB == 0:
        # No cross pairs exist; report NaN explicitly instead of relying on
        # np.mean([]), which warns.
        return np.nan, np.nan
    n_modes = A.shape[-1]
    if nA * nB <= max_pairs:
        # Enumerate every cross pair once rather than sampling with replacement.
        pair_iter = ((a, b) for a in range(nA) for b in range(nB))
    else:
        rng = np.random.default_rng(seed)
        pair_iter = ((rng.integers(nA), rng.integers(nB)) for _ in range(max_pairs))
    overlaps = []
    for a, b in pair_iter:
        M = A[a].T @ B[b]  # (k, k)
        overlaps.append((M ** 2).sum() / n_modes)
    return float(np.mean(overlaps)), float(np.std(overlaps))
ovl_sf, std_sf = cross_overlap(Qflat[exact], Qflat[~exact])

print(f"  succ-succ: {ovl_ss:.4f} ± {std_ss:.4f}")
print(f"  fail-fail: {ovl_ff:.4f} ± {std_ff:.4f}")
print(f"  succ-fail: {ovl_sf:.4f} ± {std_sf:.4f}")

# ---- Final ratio of subspace overlap inside vs cross-group ----
# If succ-succ and fail-fail >> succ-fail, the two groups live in DIFFERENT subspaces.
# If they're similar, the unstable subspace is shared.
# Ratio = mean of the finite within-group overlaps / cross-group overlap:
# ~1 => shared subspace, >>1 => group-specific subspaces.
within = [v for v in (ovl_ss, ovl_ff) if np.isfinite(v)]
if within and np.isfinite(ovl_sf) and ovl_sf > 0:
    print(f"  within/cross ratio: {np.mean(within) / ovl_sf:.3f}")
else:
    print("  within/cross ratio: n/a (empty group or zero cross-group overlap)")

print("\nplots saved to", OUT)