"""HRM Sudoku Lyapunov / trajectory diagnostic. Loads a trained HRM checkpoint, runs inference on a sample of the test set, records the recursion trajectory (z_H, z_L at every (act_step, h_cycle, l_cycle)), and computes the top Lyapunov exponent of the recursion Jacobian via power iteration with JVP. Splits samples by success / failure and writes a npz. """ from __future__ import annotations import os, sys, yaml, math, argparse, json, time from pathlib import Path import numpy as np import torch import torch.nn.functional as F HRM_DIR = Path("/home/yurenh2/rrm/hrm") sys.path.insert(0, str(HRM_DIR)) from models.hrm.hrm_act_v1 import ( HierarchicalReasoningModel_ACTV1, HierarchicalReasoningModel_ACTV1Config, HierarchicalReasoningModel_ACTV1Carry, HierarchicalReasoningModel_ACTV1InnerCarry, ) def load_model(ckpt_root: Path, ckpt_name: str, device: str = "cuda"): cfg_path = ckpt_root / "all_config.yaml" cfg = yaml.safe_load(cfg_path.read_text()) arch_cfg = cfg["arch"] # Need batch_size, seq_len, vocab_size, num_puzzle_identifiers — read from train metadata train_meta = json.loads((Path(cfg["data_path"]) / "train" / "dataset.json").read_text()) arch_cfg = dict(arch_cfg) arch_cfg["batch_size"] = cfg["global_batch_size"] arch_cfg["seq_len"] = train_meta["seq_len"] arch_cfg["vocab_size"] = train_meta["vocab_size"] arch_cfg["num_puzzle_identifiers"] = train_meta["num_puzzle_identifiers"] arch_cfg["causal"] = False model = HierarchicalReasoningModel_ACTV1(arch_cfg) sd = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True) # Strip torch.compile (`_orig_mod.`) and ACTLossHead wrapper (`model.`) prefixes stripped = {} for k, v in sd.items(): nk = k for prefix in ("_orig_mod.", "model."): if nk.startswith(prefix): nk = nk[len(prefix):] stripped[nk] = v missing, unexpected = model.load_state_dict(stripped, strict=False) if missing or unexpected: print(f"[load] missing={len(missing)} unexpected={len(unexpected)}; " f"sample missing={missing[:3]}, sample unexpected={unexpected[:3]}") model.to(device).eval() return model, cfg, train_meta def load_test_samples(data_path: Path, n_total: int, shard_id: int = 0, num_shards: int = 1, seed: int = 0): """Choose a deterministic set of n_total samples using `seed`, then return shard `shard_id`.""" rng = np.random.default_rng(seed) inputs = np.load(data_path / "test" / "all__inputs.npy") labels = np.load(data_path / "test" / "all__labels.npy") pid = np.load(data_path / "test" / "all__puzzle_identifiers.npy") all_idx = rng.choice(len(inputs), size=n_total, replace=False) shard_size = (n_total + num_shards - 1) // num_shards s, e = shard_id * shard_size, min((shard_id + 1) * shard_size, n_total) idx = all_idx[s:e] return { "inputs": torch.from_numpy(inputs[idx].astype(np.int32)), "labels": torch.from_numpy(labels[idx].astype(np.int32)), "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)), "idx": idx, } def inner_step(inner, z_H, z_L, input_embeddings, seq_info): """One *full* inner forward = H_cycles x L_cycles cycles, exactly mirroring the training-time recursion but with gradient enabled throughout (we need Jacobians). Returns the *new* (z_H, z_L) and a list of intermediate states. The natural unit step we use for Lyapunov is one *L_level* application; the extra H_level update at the end of each H_cycle is also included as a step. """ trajectory = [(z_H.detach().clone(), z_L.detach().clone())] for _ in range(inner.config.H_cycles): for _ in range(inner.config.L_cycles): z_L = inner.L_level(z_L, z_H + input_embeddings, **seq_info) trajectory.append((z_H.detach().clone(), z_L.detach().clone())) z_H = inner.H_level(z_H, z_L, **seq_info) trajectory.append((z_H.detach().clone(), z_L.detach().clone())) return z_H, z_L, trajectory def _flatten(z): """(B, seq, hidden) → (B, seq*hidden).""" return z.reshape(z.shape[0], -1) def _unflatten(v_flat, B, seq, hidden): return v_flat.reshape(B, seq, hidden) def jvp_apply_D(f, x, V): """Compute D_f(x) @ V where V has shape (B, state_dim, k). Returns f(x) (computed once with the LAST tangent), plus stacked Dv with same shape as V. We do k separate JVPs. """ B, state_dim, k = V.shape out_list = [] fx_last = None for i in range(k): v_i = V[..., i].view_as(x) # (B, seq, hidden) fx, Dv = torch.autograd.functional.jvp(f, x, v=v_i, create_graph=False, strict=False) out_list.append(_flatten(Dv).to(torch.float32)) fx_last = fx DV = torch.stack(out_list, dim=-1) # (B, state_dim, k) return fx_last, DV def run_diagnose_batch(model, batch, device, halt_max_steps, compute_lyap=True, k_lyap=8, t_ons=1, seed=0): """Run inference and collect trajectory + top-k Lyapunov for each sample in batch. For Lyapunov: maintain an orthonormal basis Q of size (state_dim, k_lyap) per sample. At each (h_cycle, l_cycle) step we apply D_t (the Jacobian of one L_level or H_level update) via JVP, then QR-reorthonormalize every t_ons steps and accumulate log|R_ii|. λ_i = (1/T) Σ_t log|R_ii(t)|. """ inner = model.inner B = batch["inputs"].shape[0] seq_full = train_meta_seq_full hidden = inner.config.hidden_size state_dim = seq_full * hidden # Initialize carry z_H = inner.H_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) z_L = inner.L_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) seq_info = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None) input_embeddings = inner._input_embeddings(batch["inputs"].to(device), batch["puzzle_identifiers"].to(device)) # Initialize orthonormal Q basis for top-k Lyapunov if compute_lyap and k_lyap > 0: torch.manual_seed(seed) # Init random Gaussian then QR Q0 = torch.randn(B, state_dim, k_lyap, device=device, dtype=torch.float32) Q, _ = torch.linalg.qr(Q0) # Q: (B, state_dim, k_lyap), orthonormal columns log_R_sum = torch.zeros(B, k_lyap, device=device, dtype=torch.float32) n_lyap_steps = 0 step_counter = 0 drift_zH_per_step = [] drift_zL_per_step = [] halted_at = torch.zeros(B, dtype=torch.long, device=device) q_halt_history = [] final_logits = None for act_step in range(halt_max_steps): z_H_prev = z_H.detach().clone() z_L_prev = z_L.detach().clone() if compute_lyap and k_lyap > 0: with torch.enable_grad(): zH = z_H.detach() zL = z_L.detach() for _h in range(inner.config.H_cycles): for _l in range(inner.config.L_cycles): f = lambda x: inner.L_level(x, zH + input_embeddings, **seq_info) zL_new, DV = jvp_apply_D(f, zL, Q) # DV: (B, state_dim, k) Q = DV # evolved tangent zL = zL_new step_counter += 1 if step_counter % t_ons == 0: Q, R = torch.linalg.qr(Q) # Q (B, state_dim, k), R (B, k, k) log_R_sum += R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log() n_lyap_steps += 1 f = lambda x: inner.H_level(x, zL, **seq_info) zH_new, DV = jvp_apply_D(f, zH, Q) Q = DV zH = zH_new step_counter += 1 if step_counter % t_ons == 0: Q, R = torch.linalg.qr(Q) log_R_sum += R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log() n_lyap_steps += 1 z_H = zH z_L = zL else: with torch.no_grad(): for _h in range(inner.config.H_cycles): for _l in range(inner.config.L_cycles): z_L = inner.L_level(z_L, z_H + input_embeddings, **seq_info) z_H = inner.H_level(z_H, z_L, **seq_info) drift_zH_per_step.append((z_H - z_H_prev).float().flatten(1).norm(dim=1).cpu()) drift_zL_per_step.append((z_L - z_L_prev).float().flatten(1).norm(dim=1).cpu()) with torch.no_grad(): q_logits = inner.q_head(z_H[:, 0]).float() q_halt = q_logits[..., 0]; q_continue = q_logits[..., 1] q_halt_history.append((q_halt.cpu(), q_continue.cpu())) newly = (q_halt > q_continue) & (halted_at == 0) halted_at[newly] = act_step + 1 output = inner.lm_head(z_H)[:, inner.puzzle_emb_len:].float() final_logits = output lyap_spec = (log_R_sum / max(n_lyap_steps, 1)).cpu().numpy() if (compute_lyap and k_lyap > 0) else None with torch.no_grad(): preds = final_logits.argmax(dim=-1) labels = batch["labels"].to(device) mask = labels > 0 exact = ((preds == labels) | ~mask).all(dim=-1).cpu().float() token_acc = ((preds == labels) & mask).sum(-1).float() / mask.sum(-1).float().clamp_min(1) token_acc = token_acc.cpu() return { "drift_zH": torch.stack(drift_zH_per_step, dim=1).numpy(), "drift_zL": torch.stack(drift_zL_per_step, dim=1).numpy(), "halted_at": halted_at.cpu().numpy(), "q_halt": torch.stack([h[0] for h in q_halt_history], dim=1).numpy(), "q_continue": torch.stack([h[1] for h in q_halt_history], dim=1).numpy(), "lyap_spec": lyap_spec, # (B, k_lyap) "exact_correct": exact.numpy(), "token_acc": token_acc.numpy(), } def main(): ap = argparse.ArgumentParser() ap.add_argument("--ckpt-root", required=True, help="path containing all_config.yaml and step_X") ap.add_argument("--ckpt-name", default="step_26040") ap.add_argument("--n-samples", type=int, default=5, help="total sample pool") ap.add_argument("--shard-id", type=int, default=0) ap.add_argument("--num-shards", type=int, default=1) ap.add_argument("--batch-size", type=int, default=64) ap.add_argument("--out", default="diagnose_out.npz") ap.add_argument("--seed", type=int, default=0) ap.add_argument("--no-lyap", action="store_true") ap.add_argument("--k-lyap", type=int, default=8, help="top-k Lyapunov exponents to compute") ap.add_argument("--t-ons", type=int, default=1, help="QR reorthonormalization interval") args = ap.parse_args() device = "cuda" print(f"Loading model from {args.ckpt_root}/{args.ckpt_name} ...") model, cfg, train_meta = load_model(Path(args.ckpt_root), args.ckpt_name, device) global train_meta_seq_full train_meta_seq_full = train_meta["seq_len"] + model.inner.puzzle_emb_len print(f" hidden={model.inner.config.hidden_size}, seq_full={train_meta_seq_full}, " f"halt_max_steps={model.inner.config.halt_max_steps}, " f"H_cycles={model.inner.config.H_cycles}, L_cycles={model.inner.config.L_cycles}") test_samples = load_test_samples(Path(cfg["data_path"]), args.n_samples, shard_id=args.shard_id, num_shards=args.num_shards, seed=args.seed) n_this_shard = len(test_samples['inputs']) print(f"Loaded shard {args.shard_id}/{args.num_shards}: {n_this_shard} samples") results = {k: [] for k in ["drift_zH","drift_zL","halted_at","q_halt","q_continue", "lyap_spec","exact_correct","token_acc","idx"]} t0 = time.time() for s in range(0, n_this_shard, args.batch_size): e = min(s + args.batch_size, n_this_shard) batch = {k: test_samples[k][s:e].to(device) for k in ["inputs","labels","puzzle_identifiers"]} out = run_diagnose_batch( model, batch, device, halt_max_steps=model.inner.config.halt_max_steps, compute_lyap=not args.no_lyap, k_lyap=args.k_lyap, t_ons=args.t_ons, seed=args.seed + s, ) for k, v in out.items(): if v is not None: results[k].append(v) results["idx"].append(test_samples["idx"][s:e]) lyap_str = (f" lyap_max={out['lyap_spec'][:,0].mean():.4f} " f"lyap_min={out['lyap_spec'][:,-1].mean():.4f}" if out["lyap_spec"] is not None else "") print(f" [{e}/{n_this_shard}] dt={time.time()-t0:.1f}s " f"exact={out['exact_correct'].mean():.3f}{lyap_str}", flush=True) # Stack saved = {} for k, v in results.items(): if not v: continue try: saved[k] = np.concatenate(v, axis=0) except ValueError: saved[k] = np.stack(v, axis=0) np.savez_compressed(args.out, **saved) print(f"saved to {args.out}") print(f"summary:") print(f" N={len(saved['exact_correct'])} acc={saved['exact_correct'].mean():.3f}") if "lyap_spec" in saved: ls = saved["lyap_spec"] # (N, k) succ = saved["exact_correct"] > 0.5 print(f" lyap_spec shape: {ls.shape}") for i in range(ls.shape[1]): li = ls[:, i] print(f" λ_{i+1}: overall={li.mean():+.4f}±{li.std():.4f} " f"succ={li[succ].mean():+.4f} fail={li[~succ].mean():+.4f} " f"Δ={li[~succ].mean()-li[succ].mean():+.4f}") if __name__ == "__main__": main()