summaryrefslogtreecommitdiff
path: root/research/flossing/diagnose_hrm.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/diagnose_hrm.py')
-rw-r--r--research/flossing/diagnose_hrm.py306
1 files changed, 306 insertions, 0 deletions
diff --git a/research/flossing/diagnose_hrm.py b/research/flossing/diagnose_hrm.py
new file mode 100644
index 0000000..193fa41
--- /dev/null
+++ b/research/flossing/diagnose_hrm.py
@@ -0,0 +1,306 @@
+"""HRM Sudoku Lyapunov / trajectory diagnostic.
+
+Loads a trained HRM checkpoint, runs inference on a sample of the test set,
+records the recursion trajectory (z_H, z_L at every (act_step, h_cycle, l_cycle)),
+and computes the top Lyapunov exponent of the recursion Jacobian via power
+iteration with JVP. Splits samples by success / failure and writes a npz.
+"""
+from __future__ import annotations
+import os, sys, yaml, math, argparse, json, time
+from pathlib import Path
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+HRM_DIR = Path("/home/yurenh2/rrm/hrm")
+sys.path.insert(0, str(HRM_DIR))
+
+from models.hrm.hrm_act_v1 import (
+ HierarchicalReasoningModel_ACTV1,
+ HierarchicalReasoningModel_ACTV1Config,
+ HierarchicalReasoningModel_ACTV1Carry,
+ HierarchicalReasoningModel_ACTV1InnerCarry,
+)
+
+
+def load_model(ckpt_root: Path, ckpt_name: str, device: str = "cuda"):
+ cfg_path = ckpt_root / "all_config.yaml"
+ cfg = yaml.safe_load(cfg_path.read_text())
+ arch_cfg = cfg["arch"]
+ # Need batch_size, seq_len, vocab_size, num_puzzle_identifiers — read from train metadata
+ train_meta = json.loads((Path(cfg["data_path"]) / "train" / "dataset.json").read_text())
+ arch_cfg = dict(arch_cfg)
+ arch_cfg["batch_size"] = cfg["global_batch_size"]
+ arch_cfg["seq_len"] = train_meta["seq_len"]
+ arch_cfg["vocab_size"] = train_meta["vocab_size"]
+ arch_cfg["num_puzzle_identifiers"] = train_meta["num_puzzle_identifiers"]
+ arch_cfg["causal"] = False
+ model = HierarchicalReasoningModel_ACTV1(arch_cfg)
+ sd = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True)
+ # Strip torch.compile (`_orig_mod.`) and ACTLossHead wrapper (`model.`) prefixes
+ stripped = {}
+ for k, v in sd.items():
+ nk = k
+ for prefix in ("_orig_mod.", "model."):
+ if nk.startswith(prefix):
+ nk = nk[len(prefix):]
+ stripped[nk] = v
+ missing, unexpected = model.load_state_dict(stripped, strict=False)
+ if missing or unexpected:
+ print(f"[load] missing={len(missing)} unexpected={len(unexpected)}; "
+ f"sample missing={missing[:3]}, sample unexpected={unexpected[:3]}")
+ model.to(device).eval()
+ return model, cfg, train_meta
+
+
+def load_test_samples(data_path: Path, n_total: int, shard_id: int = 0, num_shards: int = 1, seed: int = 0):
+ """Choose a deterministic set of n_total samples using `seed`, then return shard `shard_id`."""
+ rng = np.random.default_rng(seed)
+ inputs = np.load(data_path / "test" / "all__inputs.npy")
+ labels = np.load(data_path / "test" / "all__labels.npy")
+ pid = np.load(data_path / "test" / "all__puzzle_identifiers.npy")
+ all_idx = rng.choice(len(inputs), size=n_total, replace=False)
+ shard_size = (n_total + num_shards - 1) // num_shards
+ s, e = shard_id * shard_size, min((shard_id + 1) * shard_size, n_total)
+ idx = all_idx[s:e]
+ return {
+ "inputs": torch.from_numpy(inputs[idx].astype(np.int32)),
+ "labels": torch.from_numpy(labels[idx].astype(np.int32)),
+ "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)),
+ "idx": idx,
+ }
+
+
+def inner_step(inner, z_H, z_L, input_embeddings, seq_info):
+ """One *full* inner forward = H_cycles x L_cycles cycles, exactly mirroring the
+ training-time recursion but with gradient enabled throughout (we need Jacobians).
+ Returns the *new* (z_H, z_L) and a list of intermediate states.
+
+ The natural unit step we use for Lyapunov is one *L_level* application; the
+ extra H_level update at the end of each H_cycle is also included as a step.
+ """
+ trajectory = [(z_H.detach().clone(), z_L.detach().clone())]
+ for _ in range(inner.config.H_cycles):
+ for _ in range(inner.config.L_cycles):
+ z_L = inner.L_level(z_L, z_H + input_embeddings, **seq_info)
+ trajectory.append((z_H.detach().clone(), z_L.detach().clone()))
+ z_H = inner.H_level(z_H, z_L, **seq_info)
+ trajectory.append((z_H.detach().clone(), z_L.detach().clone()))
+ return z_H, z_L, trajectory
+
+
+def _flatten(z):
+ """(B, seq, hidden) → (B, seq*hidden)."""
+ return z.reshape(z.shape[0], -1)
+
+
+def _unflatten(v_flat, B, seq, hidden):
+ return v_flat.reshape(B, seq, hidden)
+
+
+def jvp_apply_D(f, x, V):
+ """Compute D_f(x) @ V where V has shape (B, state_dim, k).
+
+ Returns f(x) (computed once with the LAST tangent), plus stacked Dv with same shape as V.
+ We do k separate JVPs.
+ """
+ B, state_dim, k = V.shape
+ out_list = []
+ fx_last = None
+ for i in range(k):
+ v_i = V[..., i].view_as(x) # (B, seq, hidden)
+ fx, Dv = torch.autograd.functional.jvp(f, x, v=v_i, create_graph=False, strict=False)
+ out_list.append(_flatten(Dv).to(torch.float32))
+ fx_last = fx
+ DV = torch.stack(out_list, dim=-1) # (B, state_dim, k)
+ return fx_last, DV
+
+
+def run_diagnose_batch(model, batch, device, halt_max_steps, compute_lyap=True, k_lyap=8, t_ons=1, seed=0):
+ """Run inference and collect trajectory + top-k Lyapunov for each sample in batch.
+
+ For Lyapunov: maintain an orthonormal basis Q of size (state_dim, k_lyap) per sample.
+ At each (h_cycle, l_cycle) step we apply D_t (the Jacobian of one L_level or H_level
+ update) via JVP, then QR-reorthonormalize every t_ons steps and accumulate log|R_ii|.
+ λ_i = (1/T) Σ_t log|R_ii(t)|.
+ """
+ inner = model.inner
+ B = batch["inputs"].shape[0]
+ seq_full = train_meta_seq_full
+ hidden = inner.config.hidden_size
+ state_dim = seq_full * hidden
+
+ # Initialize carry
+ z_H = inner.H_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype)
+ z_L = inner.L_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype)
+
+ seq_info = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None)
+ input_embeddings = inner._input_embeddings(batch["inputs"].to(device),
+ batch["puzzle_identifiers"].to(device))
+
+ # Initialize orthonormal Q basis for top-k Lyapunov
+ if compute_lyap and k_lyap > 0:
+ torch.manual_seed(seed)
+ # Init random Gaussian then QR
+ Q0 = torch.randn(B, state_dim, k_lyap, device=device, dtype=torch.float32)
+ Q, _ = torch.linalg.qr(Q0) # Q: (B, state_dim, k_lyap), orthonormal columns
+ log_R_sum = torch.zeros(B, k_lyap, device=device, dtype=torch.float32)
+ n_lyap_steps = 0
+ step_counter = 0
+
+ drift_zH_per_step = []
+ drift_zL_per_step = []
+ halted_at = torch.zeros(B, dtype=torch.long, device=device)
+ q_halt_history = []
+
+ final_logits = None
+ for act_step in range(halt_max_steps):
+ z_H_prev = z_H.detach().clone()
+ z_L_prev = z_L.detach().clone()
+
+ if compute_lyap and k_lyap > 0:
+ with torch.enable_grad():
+ zH = z_H.detach()
+ zL = z_L.detach()
+ for _h in range(inner.config.H_cycles):
+ for _l in range(inner.config.L_cycles):
+ f = lambda x: inner.L_level(x, zH + input_embeddings, **seq_info)
+ zL_new, DV = jvp_apply_D(f, zL, Q) # DV: (B, state_dim, k)
+ Q = DV # evolved tangent
+ zL = zL_new
+ step_counter += 1
+ if step_counter % t_ons == 0:
+ Q, R = torch.linalg.qr(Q) # Q (B, state_dim, k), R (B, k, k)
+ log_R_sum += R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log()
+ n_lyap_steps += 1
+ f = lambda x: inner.H_level(x, zL, **seq_info)
+ zH_new, DV = jvp_apply_D(f, zH, Q)
+ Q = DV
+ zH = zH_new
+ step_counter += 1
+ if step_counter % t_ons == 0:
+ Q, R = torch.linalg.qr(Q)
+ log_R_sum += R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log()
+ n_lyap_steps += 1
+ z_H = zH
+ z_L = zL
+ else:
+ with torch.no_grad():
+ for _h in range(inner.config.H_cycles):
+ for _l in range(inner.config.L_cycles):
+ z_L = inner.L_level(z_L, z_H + input_embeddings, **seq_info)
+ z_H = inner.H_level(z_H, z_L, **seq_info)
+
+ drift_zH_per_step.append((z_H - z_H_prev).float().flatten(1).norm(dim=1).cpu())
+ drift_zL_per_step.append((z_L - z_L_prev).float().flatten(1).norm(dim=1).cpu())
+
+ with torch.no_grad():
+ q_logits = inner.q_head(z_H[:, 0]).float()
+ q_halt = q_logits[..., 0]; q_continue = q_logits[..., 1]
+ q_halt_history.append((q_halt.cpu(), q_continue.cpu()))
+ newly = (q_halt > q_continue) & (halted_at == 0)
+ halted_at[newly] = act_step + 1
+ output = inner.lm_head(z_H)[:, inner.puzzle_emb_len:].float()
+ final_logits = output
+
+ lyap_spec = (log_R_sum / max(n_lyap_steps, 1)).cpu().numpy() if (compute_lyap and k_lyap > 0) else None
+
+ with torch.no_grad():
+ preds = final_logits.argmax(dim=-1)
+ labels = batch["labels"].to(device)
+ mask = labels > 0
+ exact = ((preds == labels) | ~mask).all(dim=-1).cpu().float()
+ token_acc = ((preds == labels) & mask).sum(-1).float() / mask.sum(-1).float().clamp_min(1)
+ token_acc = token_acc.cpu()
+
+ return {
+ "drift_zH": torch.stack(drift_zH_per_step, dim=1).numpy(),
+ "drift_zL": torch.stack(drift_zL_per_step, dim=1).numpy(),
+ "halted_at": halted_at.cpu().numpy(),
+ "q_halt": torch.stack([h[0] for h in q_halt_history], dim=1).numpy(),
+ "q_continue": torch.stack([h[1] for h in q_halt_history], dim=1).numpy(),
+ "lyap_spec": lyap_spec, # (B, k_lyap)
+ "exact_correct": exact.numpy(),
+ "token_acc": token_acc.numpy(),
+ }
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--ckpt-root", required=True,
+ help="path containing all_config.yaml and step_X")
+ ap.add_argument("--ckpt-name", default="step_26040")
+ ap.add_argument("--n-samples", type=int, default=5, help="total sample pool")
+ ap.add_argument("--shard-id", type=int, default=0)
+ ap.add_argument("--num-shards", type=int, default=1)
+ ap.add_argument("--batch-size", type=int, default=64)
+ ap.add_argument("--out", default="diagnose_out.npz")
+ ap.add_argument("--seed", type=int, default=0)
+ ap.add_argument("--no-lyap", action="store_true")
+ ap.add_argument("--k-lyap", type=int, default=8, help="top-k Lyapunov exponents to compute")
+ ap.add_argument("--t-ons", type=int, default=1, help="QR reorthonormalization interval")
+ args = ap.parse_args()
+
+ device = "cuda"
+ print(f"Loading model from {args.ckpt_root}/{args.ckpt_name} ...")
+ model, cfg, train_meta = load_model(Path(args.ckpt_root), args.ckpt_name, device)
+ global train_meta_seq_full
+ train_meta_seq_full = train_meta["seq_len"] + model.inner.puzzle_emb_len
+ print(f" hidden={model.inner.config.hidden_size}, seq_full={train_meta_seq_full}, "
+ f"halt_max_steps={model.inner.config.halt_max_steps}, "
+ f"H_cycles={model.inner.config.H_cycles}, L_cycles={model.inner.config.L_cycles}")
+
+ test_samples = load_test_samples(Path(cfg["data_path"]), args.n_samples,
+ shard_id=args.shard_id, num_shards=args.num_shards,
+ seed=args.seed)
+ n_this_shard = len(test_samples['inputs'])
+ print(f"Loaded shard {args.shard_id}/{args.num_shards}: {n_this_shard} samples")
+
+ results = {k: [] for k in ["drift_zH","drift_zL","halted_at","q_halt","q_continue",
+ "lyap_spec","exact_correct","token_acc","idx"]}
+ t0 = time.time()
+ for s in range(0, n_this_shard, args.batch_size):
+ e = min(s + args.batch_size, n_this_shard)
+ batch = {k: test_samples[k][s:e].to(device)
+ for k in ["inputs","labels","puzzle_identifiers"]}
+ out = run_diagnose_batch(
+ model, batch, device,
+ halt_max_steps=model.inner.config.halt_max_steps,
+ compute_lyap=not args.no_lyap, k_lyap=args.k_lyap, t_ons=args.t_ons,
+ seed=args.seed + s,
+ )
+ for k, v in out.items():
+ if v is not None:
+ results[k].append(v)
+ results["idx"].append(test_samples["idx"][s:e])
+ lyap_str = (f" lyap_max={out['lyap_spec'][:,0].mean():.4f} "
+ f"lyap_min={out['lyap_spec'][:,-1].mean():.4f}"
+ if out["lyap_spec"] is not None else "")
+ print(f" [{e}/{n_this_shard}] dt={time.time()-t0:.1f}s "
+ f"exact={out['exact_correct'].mean():.3f}{lyap_str}", flush=True)
+
+ # Stack
+ saved = {}
+ for k, v in results.items():
+ if not v: continue
+ try:
+ saved[k] = np.concatenate(v, axis=0)
+ except ValueError:
+ saved[k] = np.stack(v, axis=0)
+ np.savez_compressed(args.out, **saved)
+ print(f"saved to {args.out}")
+ print(f"summary:")
+ print(f" N={len(saved['exact_correct'])} acc={saved['exact_correct'].mean():.3f}")
+ if "lyap_spec" in saved:
+ ls = saved["lyap_spec"] # (N, k)
+ succ = saved["exact_correct"] > 0.5
+ print(f" lyap_spec shape: {ls.shape}")
+ for i in range(ls.shape[1]):
+ li = ls[:, i]
+ print(f" λ_{i+1}: overall={li.mean():+.4f}±{li.std():.4f} "
+ f"succ={li[succ].mean():+.4f} fail={li[~succ].mean():+.4f} "
+ f"Δ={li[~succ].mean()-li[succ].mean():+.4f}")
+
+
+if __name__ == "__main__":
+ main()