"""(a) lite: For each test sample, save the final tangent basis Q (top-k modes after running through the full inference). Compute position/hidden activity profiles per mode and compare success vs failure groups. """ from __future__ import annotations import sys, os, yaml, json, time, argparse from pathlib import Path import numpy as np import torch HRM_DIR = Path("/home/yurenh2/rrm/hrm") sys.path.insert(0, str(HRM_DIR)) from models.hrm.hrm_act_v1 import HierarchicalReasoningModel_ACTV1 def load_model(ckpt_root, ckpt_name, device): cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text()) arch_cfg = dict(cfg["arch"]) train_meta = json.loads((Path(cfg["data_path"]) / "train" / "dataset.json").read_text()) arch_cfg.update(batch_size=cfg["global_batch_size"], seq_len=train_meta["seq_len"], vocab_size=train_meta["vocab_size"], num_puzzle_identifiers=train_meta["num_puzzle_identifiers"], causal=False) model = HierarchicalReasoningModel_ACTV1(arch_cfg) sd = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True) stripped = {} for k, v in sd.items(): nk = k for p in ("_orig_mod.", "model."): if nk.startswith(p): nk = nk[len(p):] stripped[nk] = v model.load_state_dict(stripped, strict=False) model.to(device).eval() return model, cfg, train_meta def jvp_one(f, x, v): return torch.autograd.functional.jvp(f, x, v=v, create_graph=False, strict=False) def run_save_final_Q(model, batch, k_lyap, device, seed): """Run inference with QR-iteration on top-k tangents; return final Q (B, seq, hidden, k) after all ACT steps. Also return exact_correct, predicted_logits. """ inner = model.inner cfg = inner.config B = batch["inputs"].shape[0] seq_full = cfg.seq_len + inner.puzzle_emb_len hidden = cfg.hidden_size state_dim = seq_full * hidden z_H = inner.H_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) z_L = inner.L_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) seq_info = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None) input_embeddings = inner._input_embeddings(batch["inputs"].to(device), batch["puzzle_identifiers"].to(device)) g = torch.Generator(device=device).manual_seed(seed) Q0 = torch.randn(B, state_dim, k_lyap, device=device, dtype=torch.float32, generator=g) Q, _ = torch.linalg.qr(Q0) with torch.enable_grad(): for _act in range(cfg.halt_max_steps): zH = z_H.detach(); zL = z_L.detach() for _h in range(cfg.H_cycles): for _l in range(cfg.L_cycles): out = [] fx_last = None f = lambda x: inner.L_level(x, zH + input_embeddings, **seq_info) for i in range(k_lyap): v_i = Q[..., i].view_as(zL) fx, Dv = jvp_one(f, zL, v_i) out.append(Dv.reshape(B, state_dim).to(torch.float32)) fx_last = fx Q = torch.stack(out, dim=-1) zL = fx_last Q, R = torch.linalg.qr(Q) out = [] f = lambda x: inner.H_level(x, zL, **seq_info) for i in range(k_lyap): v_i = Q[..., i].view_as(zH) fx, Dv = jvp_one(f, zH, v_i) out.append(Dv.reshape(B, state_dim).to(torch.float32)) fx_last = fx Q = torch.stack(out, dim=-1) zH = fx_last Q, R = torch.linalg.qr(Q) z_H, z_L = zH, zL with torch.no_grad(): output = inner.lm_head(z_H)[:, inner.puzzle_emb_len:].float() preds = output.argmax(dim=-1) labels = batch["labels"].to(device) mask = labels > 0 exact = ((preds == labels) | ~mask).all(dim=-1).cpu().float().numpy() Q_final = Q.reshape(B, seq_full, hidden, k_lyap).cpu().float().numpy() return Q_final, exact def main(): ap = argparse.ArgumentParser() ap.add_argument("--ckpt-root", required=True) ap.add_argument("--ckpt-name", default="step_26040") ap.add_argument("--n-samples", type=int, default=256) ap.add_argument("--batch-size", type=int, default=32) ap.add_argument("--k-lyap", type=int, default=4) ap.add_argument("--seed", type=int, default=0) ap.add_argument("--out", default="tangent_modes.npz") args = ap.parse_args() device = "cuda" model, cfg, train_meta = load_model(Path(args.ckpt_root), args.ckpt_name, device) rng = np.random.default_rng(args.seed) data_path = Path(cfg["data_path"]) inputs = np.load(data_path / "test" / "all__inputs.npy") labels = np.load(data_path / "test" / "all__labels.npy") pid = np.load(data_path / "test" / "all__puzzle_identifiers.npy") idx = rng.choice(len(inputs), size=args.n_samples, replace=False) Q_all = []; exact_all = [] t0 = time.time() for s in range(0, args.n_samples, args.batch_size): e = min(s + args.batch_size, args.n_samples) bidx = idx[s:e] batch = { "inputs": torch.from_numpy(inputs[bidx].astype(np.int32)), "labels": torch.from_numpy(labels[bidx].astype(np.int32)), "puzzle_identifiers": torch.from_numpy(pid[bidx].astype(np.int32)), } Q_final, exact = run_save_final_Q(model, batch, args.k_lyap, device, seed=args.seed + s) Q_all.append(Q_final); exact_all.append(exact) print(f" [{e}/{args.n_samples}] dt={time.time()-t0:.1f}s exact={exact.mean():.3f}", flush=True) Q_all = np.concatenate(Q_all, axis=0) # (N, seq, hidden, k) exact_all = np.concatenate(exact_all, axis=0) # (N,) print(f"saved shape Q={Q_all.shape}, exact={exact_all.shape}, acc={exact_all.mean():.3f}") np.savez_compressed(args.out, Q_final=Q_all, exact_correct=exact_all, sample_idx=idx) print(f"saved → {args.out}") if __name__ == "__main__": main()