"""HRM diagnostic: separate λ_L (固定 z_H 下 L 子系统) and λ_H (z_L 收敛后 H 子系统), in addition to the joint λ from diagnose_hrm_joint.py. Three orthonormal bases evolved in parallel: - Q_joint: (B, 2D, k) — joint (v_H, v_L). Block-matrix update per L/H step. - Q_L: (B, D, k) — only updated during L steps via J_L. Unchanged during H steps. - Q_H: (B, D, k) — only updated during H steps via J_H. Unchanged during L steps. Note: Q_L's evolution uses J_L evaluated at the current trajectory's z_L+z_H+ie, which means we measure the L sub-system Lyapunov "along the actual z_H trajectory" (z_H changes between H-cycles). Similarly for Q_H. """ from __future__ import annotations import sys, os, yaml, math, argparse, json, time from pathlib import Path import numpy as np import torch HRM_DIR = Path("/home/yurenh2/rrm/hrm") sys.path.insert(0, str(HRM_DIR)) from models.hrm.hrm_act_v1 import HierarchicalReasoningModel_ACTV1 def load_model(ckpt_root: Path, ckpt_name: str, device: str): cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text()) arch_cfg = dict(cfg["arch"]) train_meta = json.loads((Path(cfg["data_path"]) / "train" / "dataset.json").read_text()) arch_cfg.update(batch_size=cfg["global_batch_size"], seq_len=train_meta["seq_len"], vocab_size=train_meta["vocab_size"], num_puzzle_identifiers=train_meta["num_puzzle_identifiers"], causal=False) model = HierarchicalReasoningModel_ACTV1(arch_cfg) sd = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True) stripped = {k.replace("_orig_mod.", "").replace("model.", ""): v for k, v in sd.items()} model.load_state_dict(stripped, strict=False) model.to(device).eval() return model, cfg, train_meta def load_test_samples(data_path, n_total, shard_id, num_shards, seed): rng = np.random.default_rng(seed) inputs = np.load(data_path / "test" / "all__inputs.npy") labels = np.load(data_path / "test" / "all__labels.npy") pid = np.load(data_path / "test" / "all__puzzle_identifiers.npy") all_idx = rng.choice(len(inputs), size=n_total, replace=False) shard_size = (n_total + num_shards - 1) // num_shards s, e = shard_id * shard_size, min((shard_id + 1) * shard_size, n_total) idx = all_idx[s:e] return { "inputs": torch.from_numpy(inputs[idx].astype(np.int32)), "labels": torch.from_numpy(labels[idx].astype(np.int32)), "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)), "idx": idx, } def jvp(f, x, v): return torch.autograd.functional.jvp(f, x, v=v, create_graph=False, strict=False) def run_diagnose_batch(model, batch, device, k_lyap, t_ons, seed): inner = model.inner cfg = inner.config B = batch["inputs"].shape[0] seq_full = cfg.seq_len + inner.puzzle_emb_len hidden = cfg.hidden_size D = seq_full * hidden z_H = inner.H_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) z_L = inner.L_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) seq_info = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None) input_embeddings = inner._input_embeddings(batch["inputs"].to(device), batch["puzzle_identifiers"].to(device)) # Three independent orthonormal bases g = torch.Generator(device=device).manual_seed(seed) Q_joint = torch.linalg.qr(torch.randn(B, 2*D, k_lyap, device=device, dtype=torch.float32, generator=g))[0] Q_L = torch.linalg.qr(torch.randn(B, D, k_lyap, device=device, dtype=torch.float32, generator=g))[0] Q_H = torch.linalg.qr(torch.randn(B, D, k_lyap, device=device, dtype=torch.float32, generator=g))[0] log_R_joint = torch.zeros(B, k_lyap, device=device, dtype=torch.float32) log_R_L = torch.zeros(B, k_lyap, device=device, dtype=torch.float32) log_R_H = torch.zeros(B, k_lyap, device=device, dtype=torch.float32) n_joint_steps = 0; n_L_steps = 0; n_H_steps = 0 step_counter_joint = 0; step_counter_L = 0; step_counter_H = 0 for act_step in range(cfg.halt_max_steps): with torch.enable_grad(): zH, zL = z_H.detach(), z_L.detach() for _h in range(cfg.H_cycles): for _l in range(cfg.L_cycles): # ============ JOINT update (L step) ============ v_H_j = Q_joint[:, :D, :] v_L_j = Q_joint[:, D:, :] v_comb = v_H_j + v_L_j new_v_L_j_cols = [] f_L = lambda z: inner.L_level(z, zH + input_embeddings, **seq_info) for i in range(k_lyap): v_i = v_comb[:, :, i].reshape(B, seq_full, hidden).to(inner.forward_dtype) zL_new, Dv = jvp(f_L, zL, v_i) new_v_L_j_cols.append(Dv.reshape(B, D).to(torch.float32)) new_v_L_j = torch.stack(new_v_L_j_cols, dim=-1) Q_joint = torch.cat([v_H_j, new_v_L_j], dim=1) # ============ L-only update ============ new_v_L_only_cols = [] for i in range(k_lyap): v_i = Q_L[:, :, i].reshape(B, seq_full, hidden).to(inner.forward_dtype) _, Dv = jvp(f_L, zL, v_i) new_v_L_only_cols.append(Dv.reshape(B, D).to(torch.float32)) Q_L = torch.stack(new_v_L_only_cols, dim=-1) # Q_H untouched during L step (since H_level wasn't applied) zL = zL_new step_counter_joint += 1; step_counter_L += 1 if step_counter_joint % t_ons == 0: Q_joint, Rj = torch.linalg.qr(Q_joint) log_R_joint = log_R_joint + Rj.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log() n_joint_steps += 1 if step_counter_L % t_ons == 0: Q_L, Rl = torch.linalg.qr(Q_L) log_R_L = log_R_L + Rl.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log() n_L_steps += 1 # ============ JOINT update (H step) ============ v_H_j = Q_joint[:, :D, :] v_L_j = Q_joint[:, D:, :] v_comb = v_H_j + v_L_j new_v_H_j_cols = [] f_H = lambda z: inner.H_level(z, zL, **seq_info) for i in range(k_lyap): v_i = v_comb[:, :, i].reshape(B, seq_full, hidden).to(inner.forward_dtype) zH_new, Dv = jvp(f_H, zH, v_i) new_v_H_j_cols.append(Dv.reshape(B, D).to(torch.float32)) new_v_H_j = torch.stack(new_v_H_j_cols, dim=-1) Q_joint = torch.cat([new_v_H_j, v_L_j], dim=1) # ============ H-only update ============ new_v_H_only_cols = [] for i in range(k_lyap): v_i = Q_H[:, :, i].reshape(B, seq_full, hidden).to(inner.forward_dtype) _, Dv = jvp(f_H, zH, v_i) new_v_H_only_cols.append(Dv.reshape(B, D).to(torch.float32)) Q_H = torch.stack(new_v_H_only_cols, dim=-1) # Q_L untouched during H step zH = zH_new step_counter_joint += 1; step_counter_H += 1 if step_counter_joint % t_ons == 0: Q_joint, Rj = torch.linalg.qr(Q_joint) log_R_joint = log_R_joint + Rj.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log() n_joint_steps += 1 if step_counter_H % t_ons == 0: Q_H, Rh = torch.linalg.qr(Q_H) log_R_H = log_R_H + Rh.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log() n_H_steps += 1 z_H, z_L = zH, zL with torch.no_grad(): output = inner.lm_head(z_H)[:, inner.puzzle_emb_len:].float() final_logits = output lyap_joint = (log_R_joint / max(n_joint_steps, 1)).cpu().numpy() lyap_L = (log_R_L / max(n_L_steps, 1)).cpu().numpy() lyap_H = (log_R_H / max(n_H_steps, 1)).cpu().numpy() with torch.no_grad(): preds = final_logits.argmax(dim=-1) labels = batch["labels"].to(device) mask = labels > 0 exact = ((preds == labels) | ~mask).all(dim=-1).cpu().float() token_acc = ((preds == labels) & mask).sum(-1).float() / mask.sum(-1).float().clamp_min(1) token_acc = token_acc.cpu() return { "lyap_joint": lyap_joint, "lyap_L": lyap_L, "lyap_H": lyap_H, "exact_correct": exact.numpy(), "token_acc": token_acc.numpy(), } def main(): ap = argparse.ArgumentParser() ap.add_argument("--ckpt-root", required=True) ap.add_argument("--ckpt-name", default="step_26040") ap.add_argument("--n-samples", type=int, default=1024) ap.add_argument("--shard-id", type=int, default=0) ap.add_argument("--num-shards", type=int, default=1) ap.add_argument("--batch-size", type=int, default=32) ap.add_argument("--k-lyap", type=int, default=8) ap.add_argument("--t-ons", type=int, default=1) ap.add_argument("--seed", type=int, default=0) ap.add_argument("--out", default="diag_separate.npz") args = ap.parse_args() device = "cuda" model, cfg, train_meta = load_model(Path(args.ckpt_root), args.ckpt_name, device) test = load_test_samples(Path(cfg["data_path"]), args.n_samples, args.shard_id, args.num_shards, args.seed) n = len(test["inputs"]) print(f"shard {args.shard_id}/{args.num_shards}: {n} samples") print(f"H_cycles={model.inner.config.H_cycles} L_cycles={model.inner.config.L_cycles} halt={model.inner.config.halt_max_steps}") res = {k: [] for k in ["lyap_joint","lyap_L","lyap_H","exact_correct","token_acc","idx"]} t0 = time.time() for s in range(0, n, args.batch_size): e = min(s + args.batch_size, n) batch = {k: test[k][s:e].to(device) for k in ["inputs","labels","puzzle_identifiers"]} out = run_diagnose_batch(model, batch, device, args.k_lyap, args.t_ons, args.seed + s) for k, v in out.items(): res[k].append(v) res["idx"].append(test["idx"][s:e]) print(f" [{e}/{n}] dt={time.time()-t0:.1f}s exact={out['exact_correct'].mean():.3f} " f"λj1={out['lyap_joint'][:,0].mean():+.3f} " f"λL1={out['lyap_L'][:,0].mean():+.3f} " f"λH1={out['lyap_H'][:,0].mean():+.3f}", flush=True) saved = {} for k, v in res.items(): if not v: continue try: saved[k] = np.concatenate(v, 0) except ValueError: saved[k] = np.stack(v, 0) np.savez_compressed(args.out, **saved) succ = saved["exact_correct"] > 0.5 print(f"\nN={len(succ)} acc={succ.mean():.4f}") for name in ["lyap_joint", "lyap_L", "lyap_H"]: ls = saved[name] print(f"\n{name}:") print(f" i mean_succ mean_fail Δ") for i in range(ls.shape[1]): ms, mf = ls[succ,i].mean(), ls[~succ,i].mean() print(f" {i+1} {ms:+8.4f} {mf:+8.4f} {mf-ms:+8.4f}") print(f"\nsaved → {args.out}") if __name__ == "__main__": main()