summaryrefslogtreecommitdiff
path: root/research/flossing/analyze_tangent_modes.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
commit66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch)
treec29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/analyze_tangent_modes.py
rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipelineHEADmain
Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/analyze_tangent_modes.py')
-rw-r--r--research/flossing/analyze_tangent_modes.py143
1 files changed, 143 insertions, 0 deletions
diff --git a/research/flossing/analyze_tangent_modes.py b/research/flossing/analyze_tangent_modes.py
new file mode 100644
index 0000000..5801ad8
--- /dev/null
+++ b/research/flossing/analyze_tangent_modes.py
@@ -0,0 +1,143 @@
+"""(a) lite: For each test sample, save the final tangent basis Q (top-k modes after
+running through the full inference). Compute position/hidden activity profiles per
+mode and compare success vs failure groups.
+"""
+from __future__ import annotations
+import sys, os, yaml, json, time, argparse
+from pathlib import Path
+import numpy as np
+import torch
+
+HRM_DIR = Path("/home/yurenh2/rrm/hrm")
+sys.path.insert(0, str(HRM_DIR))
+from models.hrm.hrm_act_v1 import HierarchicalReasoningModel_ACTV1
+
+
+def load_model(ckpt_root, ckpt_name, device):
+ cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text())
+ arch_cfg = dict(cfg["arch"])
+ train_meta = json.loads((Path(cfg["data_path"]) / "train" / "dataset.json").read_text())
+ arch_cfg.update(batch_size=cfg["global_batch_size"], seq_len=train_meta["seq_len"],
+ vocab_size=train_meta["vocab_size"],
+ num_puzzle_identifiers=train_meta["num_puzzle_identifiers"], causal=False)
+ model = HierarchicalReasoningModel_ACTV1(arch_cfg)
+ sd = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True)
+ stripped = {}
+ for k, v in sd.items():
+ nk = k
+ for p in ("_orig_mod.", "model."):
+ if nk.startswith(p): nk = nk[len(p):]
+ stripped[nk] = v
+ model.load_state_dict(stripped, strict=False)
+ model.to(device).eval()
+ return model, cfg, train_meta
+
+
+def jvp_one(f, x, v):
+ return torch.autograd.functional.jvp(f, x, v=v, create_graph=False, strict=False)
+
+
+def run_save_final_Q(model, batch, k_lyap, device, seed):
+ """Run inference with QR-iteration on top-k tangents; return final Q (B, seq, hidden, k)
+ after all ACT steps. Also return exact_correct, predicted_logits.
+ """
+ inner = model.inner
+ cfg = inner.config
+ B = batch["inputs"].shape[0]
+ seq_full = cfg.seq_len + inner.puzzle_emb_len
+ hidden = cfg.hidden_size
+ state_dim = seq_full * hidden
+
+ z_H = inner.H_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype)
+ z_L = inner.L_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype)
+ seq_info = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None)
+ input_embeddings = inner._input_embeddings(batch["inputs"].to(device),
+ batch["puzzle_identifiers"].to(device))
+
+ g = torch.Generator(device=device).manual_seed(seed)
+ Q0 = torch.randn(B, state_dim, k_lyap, device=device, dtype=torch.float32, generator=g)
+ Q, _ = torch.linalg.qr(Q0)
+
+ with torch.enable_grad():
+ for _act in range(cfg.halt_max_steps):
+ zH = z_H.detach(); zL = z_L.detach()
+ for _h in range(cfg.H_cycles):
+ for _l in range(cfg.L_cycles):
+ out = []
+ fx_last = None
+ f = lambda x: inner.L_level(x, zH + input_embeddings, **seq_info)
+ for i in range(k_lyap):
+ v_i = Q[..., i].view_as(zL)
+ fx, Dv = jvp_one(f, zL, v_i)
+ out.append(Dv.reshape(B, state_dim).to(torch.float32))
+ fx_last = fx
+ Q = torch.stack(out, dim=-1)
+ zL = fx_last
+ Q, R = torch.linalg.qr(Q)
+ out = []
+ f = lambda x: inner.H_level(x, zL, **seq_info)
+ for i in range(k_lyap):
+ v_i = Q[..., i].view_as(zH)
+ fx, Dv = jvp_one(f, zH, v_i)
+ out.append(Dv.reshape(B, state_dim).to(torch.float32))
+ fx_last = fx
+ Q = torch.stack(out, dim=-1)
+ zH = fx_last
+ Q, R = torch.linalg.qr(Q)
+ z_H, z_L = zH, zL
+
+ with torch.no_grad():
+ output = inner.lm_head(z_H)[:, inner.puzzle_emb_len:].float()
+ preds = output.argmax(dim=-1)
+ labels = batch["labels"].to(device)
+ mask = labels > 0
+ exact = ((preds == labels) | ~mask).all(dim=-1).cpu().float().numpy()
+
+ Q_final = Q.reshape(B, seq_full, hidden, k_lyap).cpu().float().numpy()
+ return Q_final, exact
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--ckpt-root", required=True)
+ ap.add_argument("--ckpt-name", default="step_26040")
+ ap.add_argument("--n-samples", type=int, default=256)
+ ap.add_argument("--batch-size", type=int, default=32)
+ ap.add_argument("--k-lyap", type=int, default=4)
+ ap.add_argument("--seed", type=int, default=0)
+ ap.add_argument("--out", default="tangent_modes.npz")
+ args = ap.parse_args()
+ device = "cuda"
+
+ model, cfg, train_meta = load_model(Path(args.ckpt_root), args.ckpt_name, device)
+
+ rng = np.random.default_rng(args.seed)
+ data_path = Path(cfg["data_path"])
+ inputs = np.load(data_path / "test" / "all__inputs.npy")
+ labels = np.load(data_path / "test" / "all__labels.npy")
+ pid = np.load(data_path / "test" / "all__puzzle_identifiers.npy")
+ idx = rng.choice(len(inputs), size=args.n_samples, replace=False)
+
+ Q_all = []; exact_all = []
+ t0 = time.time()
+ for s in range(0, args.n_samples, args.batch_size):
+ e = min(s + args.batch_size, args.n_samples)
+ bidx = idx[s:e]
+ batch = {
+ "inputs": torch.from_numpy(inputs[bidx].astype(np.int32)),
+ "labels": torch.from_numpy(labels[bidx].astype(np.int32)),
+ "puzzle_identifiers": torch.from_numpy(pid[bidx].astype(np.int32)),
+ }
+ Q_final, exact = run_save_final_Q(model, batch, args.k_lyap, device, seed=args.seed + s)
+ Q_all.append(Q_final); exact_all.append(exact)
+ print(f" [{e}/{args.n_samples}] dt={time.time()-t0:.1f}s exact={exact.mean():.3f}", flush=True)
+
+ Q_all = np.concatenate(Q_all, axis=0) # (N, seq, hidden, k)
+ exact_all = np.concatenate(exact_all, axis=0) # (N,)
+ print(f"saved shape Q={Q_all.shape}, exact={exact_all.shape}, acc={exact_all.mean():.3f}")
+ np.savez_compressed(args.out, Q_final=Q_all, exact_correct=exact_all, sample_idx=idx)
+ print(f"saved → {args.out}")
+
+
+if __name__ == "__main__":
+ main()