"""Analyze the saved tangent basis Q (final, after all ACT steps). Q shape: (N, seq_full=82, hidden=512, k=4). Splits by success/failure and computes: - position activity per mode: ||Q[:, s, :, i]||_2 averaged - hidden-dim activity per mode - cosine similarity between succ and fail mode subspaces """ import numpy as np import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import os d = np.load("/home/yurenh2/rrm/research/flossing/tangent_modes_512.npz") Q = d["Q_final"] # (N, seq, hidden, k) exact = d["exact_correct"].astype(bool) # (N,) N, seq, hidden, K = Q.shape print(f"N={N}, seq={seq}, hidden={hidden}, K={K}, acc={exact.mean():.3f}") OUT = "/home/yurenh2/rrm/research/flossing/plots_tangent" os.makedirs(OUT, exist_ok=True) # Position activity per mode: ||Q[:, s, :, i]||_2 averaged across samples in group def pos_activity(Q_, exact_mask): # for each mode i and each position s, compute average ||Q[s, :, i]|| sub = Q_[exact_mask] # (n, seq, hidden, k) acts = np.linalg.norm(sub, axis=2) # (n, seq, k) — L2 norm over hidden dim return acts.mean(axis=0), acts.std(axis=0) # (seq, k) pa_s_mean, pa_s_std = pos_activity(Q, exact) pa_f_mean, pa_f_std = pos_activity(Q, ~exact) # Hidden activity per mode def hid_activity(Q_, exact_mask): sub = Q_[exact_mask] # (n, seq, hidden, k) acts = np.linalg.norm(sub, axis=1) # (n, hidden, k) — L2 norm over positions return acts.mean(axis=0), acts.std(axis=0) # (hidden, k) ha_s_mean, _ = hid_activity(Q, exact) ha_f_mean, _ = hid_activity(Q, ~exact) # ---- Plot 1: position activity per mode ---- fig, axes = plt.subplots(2, 2, figsize=(13, 7), sharex=True, sharey=True) for i, ax in enumerate(axes.flat): if i >= K: ax.set_visible(False); continue x = np.arange(seq) ax.plot(x, pa_s_mean[:, i], "C0-", label="succ", lw=1.5) ax.fill_between(x, pa_s_mean[:, i]-pa_s_std[:, i], pa_s_mean[:, i]+pa_s_std[:, i], color="C0", alpha=0.2) ax.plot(x, pa_f_mean[:, i], "C3-", label="fail", lw=1.5) ax.fill_between(x, pa_f_mean[:, i]-pa_f_std[:, i], pa_f_mean[:, i]+pa_f_std[:, i], color="C3", alpha=0.2) ax.axvline(0.5, color="k", ls=":", lw=0.6) # mark puzzle_emb boundary ax.set_title(f"mode {i+1}: position activity") ax.set_xlabel("seq position (0=puzzle_emb, 1-81=Sudoku cells)") ax.set_ylabel(r"$\|Q[\cdot,:,i]\|_2$") ax.legend() ax.grid(alpha=0.3) fig.suptitle(f"Per-mode position activity (N={N}, acc={exact.mean():.2%})") fig.tight_layout() fig.savefig(f"{OUT}/position_activity.png", dpi=130) plt.close() # ---- Plot 2: position activity as 9x9 grid (positions 1-81 = Sudoku cells) ---- fig, axes = plt.subplots(2, K, figsize=(K*3, 6)) for i in range(K): s_grid = pa_s_mean[1:82, i].reshape(9, 9) f_grid = pa_f_mean[1:82, i].reshape(9, 9) vmax = max(s_grid.max(), f_grid.max()) im0 = axes[0, i].imshow(s_grid, vmin=0, vmax=vmax, cmap="viridis") axes[0, i].set_title(f"succ mode {i+1}") axes[0, i].set_xticks([]); axes[0, i].set_yticks([]) im1 = axes[1, i].imshow(f_grid, vmin=0, vmax=vmax, cmap="viridis") axes[1, i].set_title(f"fail mode {i+1}") axes[1, i].set_xticks([]); axes[1, i].set_yticks([]) plt.colorbar(im1, ax=axes[:, i], fraction=0.046) fig.suptitle("Tangent mode activity over the 9x9 Sudoku board") fig.savefig(f"{OUT}/sudoku_grid_activity.png", dpi=130, bbox_inches="tight") plt.close() # ---- Plot 3: hidden activity per mode ---- fig, axes = plt.subplots(2, 2, figsize=(13, 7), sharex=True) for i, ax in enumerate(axes.flat): if i >= K: ax.set_visible(False); continue x = np.arange(hidden) # sort for cleaner viewing order = np.argsort(-(ha_s_mean[:, i] + ha_f_mean[:, i])) ax.plot(ha_s_mean[order, i], "C0-", label="succ (sorted)", lw=1) ax.plot(ha_f_mean[order, i], "C3-", label="fail (sorted)", lw=1) ax.set_title(f"mode {i+1}: hidden-dim activity (sorted by sum)") ax.set_xlabel("hidden dim (sorted)") ax.set_ylabel(r"$\|Q[:,h,i]\|_2$") ax.legend() ax.grid(alpha=0.3) fig.tight_layout() fig.savefig(f"{OUT}/hidden_activity.png", dpi=130) plt.close() # ---- Subspace cosine analysis ---- # Flatten Q to (N, seq*hidden, k). The k columns span a k-d subspace per sample. # Average outer projector P_succ = (1/n) Σ Q_n Q_n^T (over success samples). Same for fail. # Then compare top eigenvalues / overlap. Qflat = Q.reshape(N, seq*hidden, K) def projector_sample(qf): # qf: (state_dim, k). Already orthonormal columns since QR'd at end. Returns top-r eigvecs of P. # We'll compute the trace of P_s @ P_f as a "subspace overlap": Σ_ij ||q_s^i · q_f^j||² return qf def avg_pos_outer(qflat_subset): # Average of q q^T over samples — too big (state_dim x state_dim). Instead average the # k x k Gram matrices won't help since each sample has its own orthonormal basis. # Compute average squared norm of cross-projections instead: # For pairs (s_a, s_b) in same group, compute ||q_a^T q_b||_F² / k (subspace overlap) n = qflat_subset.shape[0] # Sample a subset of pairs to keep this manageable rng = np.random.default_rng(0) pairs = min(n*(n-1)//2, 5000) if n < 2: return np.nan, np.nan overlaps = [] for _ in range(pairs): a, b = rng.choice(n, size=2, replace=False) M = qflat_subset[a].T @ qflat_subset[b] # (k, k) overlaps.append((M**2).sum() / qflat_subset.shape[-1]) return float(np.mean(overlaps)), float(np.std(overlaps)) print("\nSubspace overlap (squared cosine, averaged over random pairs):") ovl_ss, std_ss = avg_pos_outer(Qflat[exact]) ovl_ff, std_ff = avg_pos_outer(Qflat[~exact]) # Cross-group overlap def cross_overlap(A, B): nA, nB = A.shape[0], B.shape[0] rng = np.random.default_rng(0) pairs = min(nA*nB, 5000) overlaps = [] for _ in range(pairs): a = rng.integers(nA); b = rng.integers(nB) M = A[a].T @ B[b] overlaps.append((M**2).sum() / A.shape[-1]) return float(np.mean(overlaps)), float(np.std(overlaps)) ovl_sf, std_sf = cross_overlap(Qflat[exact], Qflat[~exact]) print(f" succ-succ: {ovl_ss:.4f} ± {std_ss:.4f}") print(f" fail-fail: {ovl_ff:.4f} ± {std_ff:.4f}") print(f" succ-fail: {ovl_sf:.4f} ± {std_sf:.4f}") # ---- Final ratio of subspace overlap inside vs cross-group ---- # If succ-succ and fail-fail >> succ-fail, the two groups live in DIFFERENT subspaces. # If they're similar, the unstable subspace is shared. print("\nplots saved to", OUT)