summaryrefslogtreecommitdiff
path: root/research/flossing/analyze_tangent.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
commit66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch)
treec29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/analyze_tangent.py
rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipelineHEADmain
Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/analyze_tangent.py')
-rw-r--r--research/flossing/analyze_tangent.py149
1 files changed, 149 insertions, 0 deletions
diff --git a/research/flossing/analyze_tangent.py b/research/flossing/analyze_tangent.py
new file mode 100644
index 0000000..32fde5e
--- /dev/null
+++ b/research/flossing/analyze_tangent.py
@@ -0,0 +1,149 @@
+"""Analyze the saved tangent basis Q (final, after all ACT steps).
+
+Q shape: (N, seq_full=82, hidden=512, k=4).
+Splits by success/failure and computes:
+ - position activity per mode: ||Q[:, s, :, i]||_2 averaged
+ - hidden-dim activity per mode
+ - cosine similarity between succ and fail mode subspaces
+"""
+import numpy as np
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import os
+
+d = np.load("/home/yurenh2/rrm/research/flossing/tangent_modes_512.npz")
+Q = d["Q_final"] # (N, seq, hidden, k)
+exact = d["exact_correct"].astype(bool) # (N,)
+N, seq, hidden, K = Q.shape
+print(f"N={N}, seq={seq}, hidden={hidden}, K={K}, acc={exact.mean():.3f}")
+
+OUT = "/home/yurenh2/rrm/research/flossing/plots_tangent"
+os.makedirs(OUT, exist_ok=True)
+
+# Position activity per mode: ||Q[:, s, :, i]||_2 averaged across samples in group
+def pos_activity(Q_, exact_mask):
+ # for each mode i and each position s, compute average ||Q[s, :, i]||
+ sub = Q_[exact_mask] # (n, seq, hidden, k)
+ acts = np.linalg.norm(sub, axis=2) # (n, seq, k) — L2 norm over hidden dim
+ return acts.mean(axis=0), acts.std(axis=0) # (seq, k)
+
+pa_s_mean, pa_s_std = pos_activity(Q, exact)
+pa_f_mean, pa_f_std = pos_activity(Q, ~exact)
+
+# Hidden activity per mode
+def hid_activity(Q_, exact_mask):
+ sub = Q_[exact_mask] # (n, seq, hidden, k)
+ acts = np.linalg.norm(sub, axis=1) # (n, hidden, k) — L2 norm over positions
+ return acts.mean(axis=0), acts.std(axis=0) # (hidden, k)
+
+ha_s_mean, _ = hid_activity(Q, exact)
+ha_f_mean, _ = hid_activity(Q, ~exact)
+
+# ---- Plot 1: position activity per mode ----
+fig, axes = plt.subplots(2, 2, figsize=(13, 7), sharex=True, sharey=True)
+for i, ax in enumerate(axes.flat):
+ if i >= K: ax.set_visible(False); continue
+ x = np.arange(seq)
+ ax.plot(x, pa_s_mean[:, i], "C0-", label="succ", lw=1.5)
+ ax.fill_between(x, pa_s_mean[:, i]-pa_s_std[:, i], pa_s_mean[:, i]+pa_s_std[:, i], color="C0", alpha=0.2)
+ ax.plot(x, pa_f_mean[:, i], "C3-", label="fail", lw=1.5)
+ ax.fill_between(x, pa_f_mean[:, i]-pa_f_std[:, i], pa_f_mean[:, i]+pa_f_std[:, i], color="C3", alpha=0.2)
+ ax.axvline(0.5, color="k", ls=":", lw=0.6) # mark puzzle_emb boundary
+ ax.set_title(f"mode {i+1}: position activity")
+ ax.set_xlabel("seq position (0=puzzle_emb, 1-81=Sudoku cells)")
+ ax.set_ylabel(r"$\|Q[\cdot,:,i]\|_2$")
+ ax.legend()
+ ax.grid(alpha=0.3)
+fig.suptitle(f"Per-mode position activity (N={N}, acc={exact.mean():.2%})")
+fig.tight_layout()
+fig.savefig(f"{OUT}/position_activity.png", dpi=130)
+plt.close()
+
+# ---- Plot 2: position activity as 9x9 grid (positions 1-81 = Sudoku cells) ----
+fig, axes = plt.subplots(2, K, figsize=(K*3, 6))
+for i in range(K):
+ s_grid = pa_s_mean[1:82, i].reshape(9, 9)
+ f_grid = pa_f_mean[1:82, i].reshape(9, 9)
+ vmax = max(s_grid.max(), f_grid.max())
+ im0 = axes[0, i].imshow(s_grid, vmin=0, vmax=vmax, cmap="viridis")
+ axes[0, i].set_title(f"succ mode {i+1}")
+ axes[0, i].set_xticks([]); axes[0, i].set_yticks([])
+ im1 = axes[1, i].imshow(f_grid, vmin=0, vmax=vmax, cmap="viridis")
+ axes[1, i].set_title(f"fail mode {i+1}")
+ axes[1, i].set_xticks([]); axes[1, i].set_yticks([])
+ plt.colorbar(im1, ax=axes[:, i], fraction=0.046)
+fig.suptitle("Tangent mode activity over the 9x9 Sudoku board")
+fig.savefig(f"{OUT}/sudoku_grid_activity.png", dpi=130, bbox_inches="tight")
+plt.close()
+
+# ---- Plot 3: hidden activity per mode ----
+fig, axes = plt.subplots(2, 2, figsize=(13, 7), sharex=True)
+for i, ax in enumerate(axes.flat):
+ if i >= K: ax.set_visible(False); continue
+ x = np.arange(hidden)
+ # sort for cleaner viewing
+ order = np.argsort(-(ha_s_mean[:, i] + ha_f_mean[:, i]))
+ ax.plot(ha_s_mean[order, i], "C0-", label="succ (sorted)", lw=1)
+ ax.plot(ha_f_mean[order, i], "C3-", label="fail (sorted)", lw=1)
+ ax.set_title(f"mode {i+1}: hidden-dim activity (sorted by sum)")
+ ax.set_xlabel("hidden dim (sorted)")
+ ax.set_ylabel(r"$\|Q[:,h,i]\|_2$")
+ ax.legend()
+ ax.grid(alpha=0.3)
+fig.tight_layout()
+fig.savefig(f"{OUT}/hidden_activity.png", dpi=130)
+plt.close()
+
+# ---- Subspace cosine analysis ----
+# Flatten Q to (N, seq*hidden, k). The k columns span a k-d subspace per sample.
+# Average outer projector P_succ = (1/n) Σ Q_n Q_n^T (over success samples). Same for fail.
+# Then compare top eigenvalues / overlap.
+Qflat = Q.reshape(N, seq*hidden, K)
+def projector_sample(qf):
+ # qf: (state_dim, k). Already orthonormal columns since QR'd at end. Returns top-r eigvecs of P.
+ # We'll compute the trace of P_s @ P_f as a "subspace overlap": Σ_ij ||q_s^i · q_f^j||²
+ return qf
+
+def avg_pos_outer(qflat_subset):
+ # Average of q q^T over samples — too big (state_dim x state_dim). Instead average the
+ # k x k Gram matrices won't help since each sample has its own orthonormal basis.
+ # Compute average squared norm of cross-projections instead:
+ # For pairs (s_a, s_b) in same group, compute ||q_a^T q_b||_F² / k (subspace overlap)
+ n = qflat_subset.shape[0]
+ # Sample a subset of pairs to keep this manageable
+ rng = np.random.default_rng(0)
+ pairs = min(n*(n-1)//2, 5000)
+ if n < 2: return np.nan, np.nan
+ overlaps = []
+ for _ in range(pairs):
+ a, b = rng.choice(n, size=2, replace=False)
+ M = qflat_subset[a].T @ qflat_subset[b] # (k, k)
+ overlaps.append((M**2).sum() / qflat_subset.shape[-1])
+ return float(np.mean(overlaps)), float(np.std(overlaps))
+
+print("\nSubspace overlap (squared cosine, averaged over random pairs):")
+ovl_ss, std_ss = avg_pos_outer(Qflat[exact])
+ovl_ff, std_ff = avg_pos_outer(Qflat[~exact])
+# Cross-group overlap
+def cross_overlap(A, B):
+ nA, nB = A.shape[0], B.shape[0]
+ rng = np.random.default_rng(0)
+ pairs = min(nA*nB, 5000)
+ overlaps = []
+ for _ in range(pairs):
+ a = rng.integers(nA); b = rng.integers(nB)
+ M = A[a].T @ B[b]
+ overlaps.append((M**2).sum() / A.shape[-1])
+ return float(np.mean(overlaps)), float(np.std(overlaps))
+ovl_sf, std_sf = cross_overlap(Qflat[exact], Qflat[~exact])
+
+print(f" succ-succ: {ovl_ss:.4f} ± {std_ss:.4f}")
+print(f" fail-fail: {ovl_ff:.4f} ± {std_ff:.4f}")
+print(f" succ-fail: {ovl_sf:.4f} ± {std_sf:.4f}")
+
+# ---- Final ratio of subspace overlap inside vs cross-group ----
+# If succ-succ and fail-fail >> succ-fail, the two groups live in DIFFERENT subspaces.
+# If they're similar, the unstable subspace is shared.
+
+print("\nplots saved to", OUT)