diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
| commit | 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch) | |
| tree | c29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/analyze_tangent.py | |
Curated export for clone-and-run Maze training (2x A6000) + diagnostics.
trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible).
Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/analyze_tangent.py')
| -rw-r--r-- | research/flossing/analyze_tangent.py | 149 |
1 files changed, 149 insertions, 0 deletions
diff --git a/research/flossing/analyze_tangent.py b/research/flossing/analyze_tangent.py new file mode 100644 index 0000000..32fde5e --- /dev/null +++ b/research/flossing/analyze_tangent.py @@ -0,0 +1,149 @@ +"""Analyze the saved tangent basis Q (final, after all ACT steps). + +Q shape: (N, seq_full=82, hidden=512, k=4). +Splits by success/failure and computes: + - position activity per mode: ||Q[:, s, :, i]||_2 averaged + - hidden-dim activity per mode + - cosine similarity between succ and fail mode subspaces +""" +import numpy as np +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import os + +d = np.load("/home/yurenh2/rrm/research/flossing/tangent_modes_512.npz") +Q = d["Q_final"] # (N, seq, hidden, k) +exact = d["exact_correct"].astype(bool) # (N,) +N, seq, hidden, K = Q.shape +print(f"N={N}, seq={seq}, hidden={hidden}, K={K}, acc={exact.mean():.3f}") + +OUT = "/home/yurenh2/rrm/research/flossing/plots_tangent" +os.makedirs(OUT, exist_ok=True) + +# Position activity per mode: ||Q[:, s, :, i]||_2 averaged across samples in group +def pos_activity(Q_, exact_mask): + # for each mode i and each position s, compute average ||Q[s, :, i]|| + sub = Q_[exact_mask] # (n, seq, hidden, k) + acts = np.linalg.norm(sub, axis=2) # (n, seq, k) — L2 norm over hidden dim + return acts.mean(axis=0), acts.std(axis=0) # (seq, k) + +pa_s_mean, pa_s_std = pos_activity(Q, exact) +pa_f_mean, pa_f_std = pos_activity(Q, ~exact) + +# Hidden activity per mode +def hid_activity(Q_, exact_mask): + sub = Q_[exact_mask] # (n, seq, hidden, k) + acts = np.linalg.norm(sub, axis=1) # (n, hidden, k) — L2 norm over positions + return acts.mean(axis=0), acts.std(axis=0) # (hidden, k) + +ha_s_mean, _ = hid_activity(Q, exact) +ha_f_mean, _ = hid_activity(Q, ~exact) + +# ---- Plot 1: position activity per mode ---- +fig, axes = plt.subplots(2, 2, figsize=(13, 7), sharex=True, sharey=True) +for i, ax in enumerate(axes.flat): + if i >= K: ax.set_visible(False); continue + x = np.arange(seq) + ax.plot(x, pa_s_mean[:, i], "C0-", label="succ", lw=1.5) + ax.fill_between(x, pa_s_mean[:, i]-pa_s_std[:, i], pa_s_mean[:, i]+pa_s_std[:, i], color="C0", alpha=0.2) + ax.plot(x, pa_f_mean[:, i], "C3-", label="fail", lw=1.5) + ax.fill_between(x, pa_f_mean[:, i]-pa_f_std[:, i], pa_f_mean[:, i]+pa_f_std[:, i], color="C3", alpha=0.2) + ax.axvline(0.5, color="k", ls=":", lw=0.6) # mark puzzle_emb boundary + ax.set_title(f"mode {i+1}: position activity") + ax.set_xlabel("seq position (0=puzzle_emb, 1-81=Sudoku cells)") + ax.set_ylabel(r"$\|Q[\cdot,:,i]\|_2$") + ax.legend() + ax.grid(alpha=0.3) +fig.suptitle(f"Per-mode position activity (N={N}, acc={exact.mean():.2%})") +fig.tight_layout() +fig.savefig(f"{OUT}/position_activity.png", dpi=130) +plt.close() + +# ---- Plot 2: position activity as 9x9 grid (positions 1-81 = Sudoku cells) ---- +fig, axes = plt.subplots(2, K, figsize=(K*3, 6)) +for i in range(K): + s_grid = pa_s_mean[1:82, i].reshape(9, 9) + f_grid = pa_f_mean[1:82, i].reshape(9, 9) + vmax = max(s_grid.max(), f_grid.max()) + im0 = axes[0, i].imshow(s_grid, vmin=0, vmax=vmax, cmap="viridis") + axes[0, i].set_title(f"succ mode {i+1}") + axes[0, i].set_xticks([]); axes[0, i].set_yticks([]) + im1 = axes[1, i].imshow(f_grid, vmin=0, vmax=vmax, cmap="viridis") + axes[1, i].set_title(f"fail mode {i+1}") + axes[1, i].set_xticks([]); axes[1, i].set_yticks([]) + plt.colorbar(im1, ax=axes[:, i], fraction=0.046) +fig.suptitle("Tangent mode activity over the 9x9 Sudoku board") +fig.savefig(f"{OUT}/sudoku_grid_activity.png", dpi=130, bbox_inches="tight") +plt.close() + +# ---- Plot 3: hidden activity per mode ---- +fig, axes = plt.subplots(2, 2, figsize=(13, 7), sharex=True) +for i, ax in enumerate(axes.flat): + if i >= K: ax.set_visible(False); continue + x = np.arange(hidden) + # sort for cleaner viewing + order = np.argsort(-(ha_s_mean[:, i] + ha_f_mean[:, i])) + ax.plot(ha_s_mean[order, i], "C0-", label="succ (sorted)", lw=1) + ax.plot(ha_f_mean[order, i], "C3-", label="fail (sorted)", lw=1) + ax.set_title(f"mode {i+1}: hidden-dim activity (sorted by sum)") + ax.set_xlabel("hidden dim (sorted)") + ax.set_ylabel(r"$\|Q[:,h,i]\|_2$") + ax.legend() + ax.grid(alpha=0.3) +fig.tight_layout() +fig.savefig(f"{OUT}/hidden_activity.png", dpi=130) +plt.close() + +# ---- Subspace cosine analysis ---- +# Flatten Q to (N, seq*hidden, k). The k columns span a k-d subspace per sample. +# Average outer projector P_succ = (1/n) Σ Q_n Q_n^T (over success samples). Same for fail. +# Then compare top eigenvalues / overlap. +Qflat = Q.reshape(N, seq*hidden, K) +def projector_sample(qf): + # qf: (state_dim, k). Already orthonormal columns since QR'd at end. Returns top-r eigvecs of P. + # We'll compute the trace of P_s @ P_f as a "subspace overlap": Σ_ij ||q_s^i · q_f^j||² + return qf + +def avg_pos_outer(qflat_subset): + # Average of q q^T over samples — too big (state_dim x state_dim). Instead average the + # k x k Gram matrices won't help since each sample has its own orthonormal basis. + # Compute average squared norm of cross-projections instead: + # For pairs (s_a, s_b) in same group, compute ||q_a^T q_b||_F² / k (subspace overlap) + n = qflat_subset.shape[0] + # Sample a subset of pairs to keep this manageable + rng = np.random.default_rng(0) + pairs = min(n*(n-1)//2, 5000) + if n < 2: return np.nan, np.nan + overlaps = [] + for _ in range(pairs): + a, b = rng.choice(n, size=2, replace=False) + M = qflat_subset[a].T @ qflat_subset[b] # (k, k) + overlaps.append((M**2).sum() / qflat_subset.shape[-1]) + return float(np.mean(overlaps)), float(np.std(overlaps)) + +print("\nSubspace overlap (squared cosine, averaged over random pairs):") +ovl_ss, std_ss = avg_pos_outer(Qflat[exact]) +ovl_ff, std_ff = avg_pos_outer(Qflat[~exact]) +# Cross-group overlap +def cross_overlap(A, B): + nA, nB = A.shape[0], B.shape[0] + rng = np.random.default_rng(0) + pairs = min(nA*nB, 5000) + overlaps = [] + for _ in range(pairs): + a = rng.integers(nA); b = rng.integers(nB) + M = A[a].T @ B[b] + overlaps.append((M**2).sum() / A.shape[-1]) + return float(np.mean(overlaps)), float(np.std(overlaps)) +ovl_sf, std_sf = cross_overlap(Qflat[exact], Qflat[~exact]) + +print(f" succ-succ: {ovl_ss:.4f} ± {std_ss:.4f}") +print(f" fail-fail: {ovl_ff:.4f} ± {std_ff:.4f}") +print(f" succ-fail: {ovl_sf:.4f} ± {std_sf:.4f}") + +# ---- Final ratio of subspace overlap inside vs cross-group ---- +# If succ-succ and fail-fail >> succ-fail, the two groups live in DIFFERENT subspaces. +# If they're similar, the unstable subspace is shared. + +print("\nplots saved to", OUT) |
