diff options
Diffstat (limited to 'research/flossing/diagnose_hrm.py')
| -rw-r--r-- | research/flossing/diagnose_hrm.py | 306 |
1 files changed, 306 insertions, 0 deletions
diff --git a/research/flossing/diagnose_hrm.py b/research/flossing/diagnose_hrm.py new file mode 100644 index 0000000..193fa41 --- /dev/null +++ b/research/flossing/diagnose_hrm.py @@ -0,0 +1,306 @@ +"""HRM Sudoku Lyapunov / trajectory diagnostic. + +Loads a trained HRM checkpoint, runs inference on a sample of the test set, +records the recursion trajectory (z_H, z_L at every (act_step, h_cycle, l_cycle)), +and computes the top Lyapunov exponent of the recursion Jacobian via power +iteration with JVP. Splits samples by success / failure and writes a npz. +""" +from __future__ import annotations +import os, sys, yaml, math, argparse, json, time +from pathlib import Path +import numpy as np +import torch +import torch.nn.functional as F + +HRM_DIR = Path("/home/yurenh2/rrm/hrm") +sys.path.insert(0, str(HRM_DIR)) + +from models.hrm.hrm_act_v1 import ( + HierarchicalReasoningModel_ACTV1, + HierarchicalReasoningModel_ACTV1Config, + HierarchicalReasoningModel_ACTV1Carry, + HierarchicalReasoningModel_ACTV1InnerCarry, +) + + +def load_model(ckpt_root: Path, ckpt_name: str, device: str = "cuda"): + cfg_path = ckpt_root / "all_config.yaml" + cfg = yaml.safe_load(cfg_path.read_text()) + arch_cfg = cfg["arch"] + # Need batch_size, seq_len, vocab_size, num_puzzle_identifiers — read from train metadata + train_meta = json.loads((Path(cfg["data_path"]) / "train" / "dataset.json").read_text()) + arch_cfg = dict(arch_cfg) + arch_cfg["batch_size"] = cfg["global_batch_size"] + arch_cfg["seq_len"] = train_meta["seq_len"] + arch_cfg["vocab_size"] = train_meta["vocab_size"] + arch_cfg["num_puzzle_identifiers"] = train_meta["num_puzzle_identifiers"] + arch_cfg["causal"] = False + model = HierarchicalReasoningModel_ACTV1(arch_cfg) + sd = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True) + # Strip torch.compile (`_orig_mod.`) and ACTLossHead wrapper (`model.`) prefixes + stripped = {} + for k, v in sd.items(): + nk = k + for prefix in ("_orig_mod.", "model."): + if nk.startswith(prefix): + nk = nk[len(prefix):] + stripped[nk] = v + missing, unexpected = model.load_state_dict(stripped, strict=False) + if missing or unexpected: + print(f"[load] missing={len(missing)} unexpected={len(unexpected)}; " + f"sample missing={missing[:3]}, sample unexpected={unexpected[:3]}") + model.to(device).eval() + return model, cfg, train_meta + + +def load_test_samples(data_path: Path, n_total: int, shard_id: int = 0, num_shards: int = 1, seed: int = 0): + """Choose a deterministic set of n_total samples using `seed`, then return shard `shard_id`.""" + rng = np.random.default_rng(seed) + inputs = np.load(data_path / "test" / "all__inputs.npy") + labels = np.load(data_path / "test" / "all__labels.npy") + pid = np.load(data_path / "test" / "all__puzzle_identifiers.npy") + all_idx = rng.choice(len(inputs), size=n_total, replace=False) + shard_size = (n_total + num_shards - 1) // num_shards + s, e = shard_id * shard_size, min((shard_id + 1) * shard_size, n_total) + idx = all_idx[s:e] + return { + "inputs": torch.from_numpy(inputs[idx].astype(np.int32)), + "labels": torch.from_numpy(labels[idx].astype(np.int32)), + "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)), + "idx": idx, + } + + +def inner_step(inner, z_H, z_L, input_embeddings, seq_info): + """One *full* inner forward = H_cycles x L_cycles cycles, exactly mirroring the + training-time recursion but with gradient enabled throughout (we need Jacobians). + Returns the *new* (z_H, z_L) and a list of intermediate states. + + The natural unit step we use for Lyapunov is one *L_level* application; the + extra H_level update at the end of each H_cycle is also included as a step. + """ + trajectory = [(z_H.detach().clone(), z_L.detach().clone())] + for _ in range(inner.config.H_cycles): + for _ in range(inner.config.L_cycles): + z_L = inner.L_level(z_L, z_H + input_embeddings, **seq_info) + trajectory.append((z_H.detach().clone(), z_L.detach().clone())) + z_H = inner.H_level(z_H, z_L, **seq_info) + trajectory.append((z_H.detach().clone(), z_L.detach().clone())) + return z_H, z_L, trajectory + + +def _flatten(z): + """(B, seq, hidden) → (B, seq*hidden).""" + return z.reshape(z.shape[0], -1) + + +def _unflatten(v_flat, B, seq, hidden): + return v_flat.reshape(B, seq, hidden) + + +def jvp_apply_D(f, x, V): + """Compute D_f(x) @ V where V has shape (B, state_dim, k). + + Returns f(x) (computed once with the LAST tangent), plus stacked Dv with same shape as V. + We do k separate JVPs. + """ + B, state_dim, k = V.shape + out_list = [] + fx_last = None + for i in range(k): + v_i = V[..., i].view_as(x) # (B, seq, hidden) + fx, Dv = torch.autograd.functional.jvp(f, x, v=v_i, create_graph=False, strict=False) + out_list.append(_flatten(Dv).to(torch.float32)) + fx_last = fx + DV = torch.stack(out_list, dim=-1) # (B, state_dim, k) + return fx_last, DV + + +def run_diagnose_batch(model, batch, device, halt_max_steps, compute_lyap=True, k_lyap=8, t_ons=1, seed=0): + """Run inference and collect trajectory + top-k Lyapunov for each sample in batch. + + For Lyapunov: maintain an orthonormal basis Q of size (state_dim, k_lyap) per sample. + At each (h_cycle, l_cycle) step we apply D_t (the Jacobian of one L_level or H_level + update) via JVP, then QR-reorthonormalize every t_ons steps and accumulate log|R_ii|. + λ_i = (1/T) Σ_t log|R_ii(t)|. + """ + inner = model.inner + B = batch["inputs"].shape[0] + seq_full = train_meta_seq_full + hidden = inner.config.hidden_size + state_dim = seq_full * hidden + + # Initialize carry + z_H = inner.H_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) + z_L = inner.L_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) + + seq_info = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None) + input_embeddings = inner._input_embeddings(batch["inputs"].to(device), + batch["puzzle_identifiers"].to(device)) + + # Initialize orthonormal Q basis for top-k Lyapunov + if compute_lyap and k_lyap > 0: + torch.manual_seed(seed) + # Init random Gaussian then QR + Q0 = torch.randn(B, state_dim, k_lyap, device=device, dtype=torch.float32) + Q, _ = torch.linalg.qr(Q0) # Q: (B, state_dim, k_lyap), orthonormal columns + log_R_sum = torch.zeros(B, k_lyap, device=device, dtype=torch.float32) + n_lyap_steps = 0 + step_counter = 0 + + drift_zH_per_step = [] + drift_zL_per_step = [] + halted_at = torch.zeros(B, dtype=torch.long, device=device) + q_halt_history = [] + + final_logits = None + for act_step in range(halt_max_steps): + z_H_prev = z_H.detach().clone() + z_L_prev = z_L.detach().clone() + + if compute_lyap and k_lyap > 0: + with torch.enable_grad(): + zH = z_H.detach() + zL = z_L.detach() + for _h in range(inner.config.H_cycles): + for _l in range(inner.config.L_cycles): + f = lambda x: inner.L_level(x, zH + input_embeddings, **seq_info) + zL_new, DV = jvp_apply_D(f, zL, Q) # DV: (B, state_dim, k) + Q = DV # evolved tangent + zL = zL_new + step_counter += 1 + if step_counter % t_ons == 0: + Q, R = torch.linalg.qr(Q) # Q (B, state_dim, k), R (B, k, k) + log_R_sum += R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log() + n_lyap_steps += 1 + f = lambda x: inner.H_level(x, zL, **seq_info) + zH_new, DV = jvp_apply_D(f, zH, Q) + Q = DV + zH = zH_new + step_counter += 1 + if step_counter % t_ons == 0: + Q, R = torch.linalg.qr(Q) + log_R_sum += R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log() + n_lyap_steps += 1 + z_H = zH + z_L = zL + else: + with torch.no_grad(): + for _h in range(inner.config.H_cycles): + for _l in range(inner.config.L_cycles): + z_L = inner.L_level(z_L, z_H + input_embeddings, **seq_info) + z_H = inner.H_level(z_H, z_L, **seq_info) + + drift_zH_per_step.append((z_H - z_H_prev).float().flatten(1).norm(dim=1).cpu()) + drift_zL_per_step.append((z_L - z_L_prev).float().flatten(1).norm(dim=1).cpu()) + + with torch.no_grad(): + q_logits = inner.q_head(z_H[:, 0]).float() + q_halt = q_logits[..., 0]; q_continue = q_logits[..., 1] + q_halt_history.append((q_halt.cpu(), q_continue.cpu())) + newly = (q_halt > q_continue) & (halted_at == 0) + halted_at[newly] = act_step + 1 + output = inner.lm_head(z_H)[:, inner.puzzle_emb_len:].float() + final_logits = output + + lyap_spec = (log_R_sum / max(n_lyap_steps, 1)).cpu().numpy() if (compute_lyap and k_lyap > 0) else None + + with torch.no_grad(): + preds = final_logits.argmax(dim=-1) + labels = batch["labels"].to(device) + mask = labels > 0 + exact = ((preds == labels) | ~mask).all(dim=-1).cpu().float() + token_acc = ((preds == labels) & mask).sum(-1).float() / mask.sum(-1).float().clamp_min(1) + token_acc = token_acc.cpu() + + return { + "drift_zH": torch.stack(drift_zH_per_step, dim=1).numpy(), + "drift_zL": torch.stack(drift_zL_per_step, dim=1).numpy(), + "halted_at": halted_at.cpu().numpy(), + "q_halt": torch.stack([h[0] for h in q_halt_history], dim=1).numpy(), + "q_continue": torch.stack([h[1] for h in q_halt_history], dim=1).numpy(), + "lyap_spec": lyap_spec, # (B, k_lyap) + "exact_correct": exact.numpy(), + "token_acc": token_acc.numpy(), + } + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--ckpt-root", required=True, + help="path containing all_config.yaml and step_X") + ap.add_argument("--ckpt-name", default="step_26040") + ap.add_argument("--n-samples", type=int, default=5, help="total sample pool") + ap.add_argument("--shard-id", type=int, default=0) + ap.add_argument("--num-shards", type=int, default=1) + ap.add_argument("--batch-size", type=int, default=64) + ap.add_argument("--out", default="diagnose_out.npz") + ap.add_argument("--seed", type=int, default=0) + ap.add_argument("--no-lyap", action="store_true") + ap.add_argument("--k-lyap", type=int, default=8, help="top-k Lyapunov exponents to compute") + ap.add_argument("--t-ons", type=int, default=1, help="QR reorthonormalization interval") + args = ap.parse_args() + + device = "cuda" + print(f"Loading model from {args.ckpt_root}/{args.ckpt_name} ...") + model, cfg, train_meta = load_model(Path(args.ckpt_root), args.ckpt_name, device) + global train_meta_seq_full + train_meta_seq_full = train_meta["seq_len"] + model.inner.puzzle_emb_len + print(f" hidden={model.inner.config.hidden_size}, seq_full={train_meta_seq_full}, " + f"halt_max_steps={model.inner.config.halt_max_steps}, " + f"H_cycles={model.inner.config.H_cycles}, L_cycles={model.inner.config.L_cycles}") + + test_samples = load_test_samples(Path(cfg["data_path"]), args.n_samples, + shard_id=args.shard_id, num_shards=args.num_shards, + seed=args.seed) + n_this_shard = len(test_samples['inputs']) + print(f"Loaded shard {args.shard_id}/{args.num_shards}: {n_this_shard} samples") + + results = {k: [] for k in ["drift_zH","drift_zL","halted_at","q_halt","q_continue", + "lyap_spec","exact_correct","token_acc","idx"]} + t0 = time.time() + for s in range(0, n_this_shard, args.batch_size): + e = min(s + args.batch_size, n_this_shard) + batch = {k: test_samples[k][s:e].to(device) + for k in ["inputs","labels","puzzle_identifiers"]} + out = run_diagnose_batch( + model, batch, device, + halt_max_steps=model.inner.config.halt_max_steps, + compute_lyap=not args.no_lyap, k_lyap=args.k_lyap, t_ons=args.t_ons, + seed=args.seed + s, + ) + for k, v in out.items(): + if v is not None: + results[k].append(v) + results["idx"].append(test_samples["idx"][s:e]) + lyap_str = (f" lyap_max={out['lyap_spec'][:,0].mean():.4f} " + f"lyap_min={out['lyap_spec'][:,-1].mean():.4f}" + if out["lyap_spec"] is not None else "") + print(f" [{e}/{n_this_shard}] dt={time.time()-t0:.1f}s " + f"exact={out['exact_correct'].mean():.3f}{lyap_str}", flush=True) + + # Stack + saved = {} + for k, v in results.items(): + if not v: continue + try: + saved[k] = np.concatenate(v, axis=0) + except ValueError: + saved[k] = np.stack(v, axis=0) + np.savez_compressed(args.out, **saved) + print(f"saved to {args.out}") + print(f"summary:") + print(f" N={len(saved['exact_correct'])} acc={saved['exact_correct'].mean():.3f}") + if "lyap_spec" in saved: + ls = saved["lyap_spec"] # (N, k) + succ = saved["exact_correct"] > 0.5 + print(f" lyap_spec shape: {ls.shape}") + for i in range(ls.shape[1]): + li = ls[:, i] + print(f" λ_{i+1}: overall={li.mean():+.4f}±{li.std():.4f} " + f"succ={li[succ].mean():+.4f} fail={li[~succ].mean():+.4f} " + f"Δ={li[~succ].mean()-li[succ].mean():+.4f}") + + +if __name__ == "__main__": + main() |
