summaryrefslogtreecommitdiff
path: root/research/flossing/diagnose_hrm_joint.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/diagnose_hrm_joint.py')
-rw-r--r--research/flossing/diagnose_hrm_joint.py240
1 files changed, 240 insertions, 0 deletions
diff --git a/research/flossing/diagnose_hrm_joint.py b/research/flossing/diagnose_hrm_joint.py
new file mode 100644
index 0000000..3ba39cf
--- /dev/null
+++ b/research/flossing/diagnose_hrm_joint.py
@@ -0,0 +1,240 @@
+"""HRM Sudoku Lyapunov diagnostic with CORRECTED joint (z_H, z_L) tangent tracking.
+
+Key fix over diagnose_hrm.py:
+ - State is conceptually (z_H, z_L) ∈ R^{2D} where D = seq_full * hidden.
+ - L_level update: z_L_new = layers_L(z_L + z_H + input_embeddings), so
+ v_L_new = J_L · (v_H + v_L), v_H_new = v_H
+ - H_level update: z_H_new = layers_H(z_H + z_L), so
+ v_H_new = J_H · (v_H + v_L), v_L_new = v_L
+ - Each L or H cycle = ONE JVP per tangent column (same cost as before),
+ but operating on the combined tangent v_H + v_L.
+ - Q is (B, 2D, k); QR over the 2D dimension keeps an orthonormal basis.
+"""
+from __future__ import annotations
+import sys, os, yaml, math, argparse, json, time
+from pathlib import Path
+import numpy as np
+import torch
+
+HRM_DIR = Path("/home/yurenh2/rrm/hrm")
+sys.path.insert(0, str(HRM_DIR))
+
+from models.hrm.hrm_act_v1 import HierarchicalReasoningModel_ACTV1
+
+
+def load_model(ckpt_root: Path, ckpt_name: str, device: str):
+ cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text())
+ arch_cfg = dict(cfg["arch"])
+ train_meta = json.loads((Path(cfg["data_path"]) / "train" / "dataset.json").read_text())
+ arch_cfg.update(batch_size=cfg["global_batch_size"], seq_len=train_meta["seq_len"],
+ vocab_size=train_meta["vocab_size"],
+ num_puzzle_identifiers=train_meta["num_puzzle_identifiers"], causal=False)
+ model = HierarchicalReasoningModel_ACTV1(arch_cfg)
+ sd = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True)
+ stripped = {k.replace("_orig_mod.", "").replace("model.", ""): v for k, v in sd.items()}
+ model.load_state_dict(stripped, strict=False)
+ model.to(device).eval()
+ return model, cfg, train_meta
+
+
+def load_test_samples(data_path: Path, n_total: int, shard_id: int, num_shards: int, seed: int):
+ rng = np.random.default_rng(seed)
+ inputs = np.load(data_path / "test" / "all__inputs.npy")
+ labels = np.load(data_path / "test" / "all__labels.npy")
+ pid = np.load(data_path / "test" / "all__puzzle_identifiers.npy")
+ all_idx = rng.choice(len(inputs), size=n_total, replace=False)
+ shard_size = (n_total + num_shards - 1) // num_shards
+ s, e = shard_id * shard_size, min((shard_id + 1) * shard_size, n_total)
+ idx = all_idx[s:e]
+ return {
+ "inputs": torch.from_numpy(inputs[idx].astype(np.int32)),
+ "labels": torch.from_numpy(labels[idx].astype(np.int32)),
+ "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)),
+ "idx": idx,
+ }
+
+
+def jvp_through(f, x, v):
+ """One JVP. Returns (f(x), D_f(x) @ v). create_graph=False since this is diagnostic."""
+ return torch.autograd.functional.jvp(f, x, v=v, create_graph=False, strict=False)
+
+
+def run_diagnose_batch(model, batch, device, k_lyap, t_ons, seed):
+ """Compute joint top-k Lyapunov spectrum over (z_H, z_L) joint tangent.
+
+ Per L_level step:
+ v_L_new = J_L · (v_H + v_L), v_H_new = v_H
+ Per H_level step:
+ v_H_new = J_H · (v_H + v_L), v_L_new = v_L
+ """
+ inner = model.inner
+ cfg = inner.config
+ B = batch["inputs"].shape[0]
+ seq_full = cfg.seq_len + inner.puzzle_emb_len
+ hidden = cfg.hidden_size
+ state_dim = seq_full * hidden # one of (z_H or z_L)
+ total_dim = 2 * state_dim # joint (v_H, v_L)
+
+ # Carry init
+ z_H = inner.H_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype)
+ z_L = inner.L_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype)
+ seq_info = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None)
+ input_embeddings = inner._input_embeddings(batch["inputs"].to(device),
+ batch["puzzle_identifiers"].to(device))
+
+ # Joint orthonormal tangent basis
+ g = torch.Generator(device=device).manual_seed(seed)
+ Q0 = torch.randn(B, total_dim, k_lyap, device=device, dtype=torch.float32, generator=g)
+ Q, _ = torch.linalg.qr(Q0) # (B, 2D, k)
+ log_R_sum = torch.zeros(B, k_lyap, device=device, dtype=torch.float32)
+ n_lyap_steps = 0
+ step_counter = 0
+
+ drift_zH_per_step, drift_zL_per_step = [], []
+ halted_at = torch.zeros(B, dtype=torch.long, device=device)
+ q_halt_hist, q_continue_hist = [], []
+
+ for act_step in range(cfg.halt_max_steps):
+ z_H_prev = z_H.detach().clone()
+ z_L_prev = z_L.detach().clone()
+
+ with torch.enable_grad():
+ zH, zL = z_H.detach(), z_L.detach()
+ for _h in range(cfg.H_cycles):
+ for _l in range(cfg.L_cycles):
+ # --- joint tangent: prep v_combined = v_H + v_L ---
+ v_H_all = Q[:, :state_dim, :] # (B, D, k)
+ v_L_all = Q[:, state_dim:, :]
+ v_comb = v_H_all + v_L_all
+ # --- k JVPs through L_level ---
+ new_v_L_cols = []
+ f_L = lambda z: inner.L_level(z, zH + input_embeddings, **seq_info)
+ for i in range(k_lyap):
+ v_i = v_comb[:, :, i].reshape(B, seq_full, hidden).to(inner.forward_dtype)
+ zL_new, Dv = jvp_through(f_L, zL, v_i)
+ new_v_L_cols.append(Dv.reshape(B, state_dim).to(torch.float32))
+ new_v_L = torch.stack(new_v_L_cols, dim=-1) # (B, D, k)
+ # Reassemble Q (v_H unchanged, v_L updated)
+ Q = torch.cat([v_H_all, new_v_L], dim=1)
+ zL = zL_new
+ step_counter += 1
+ if step_counter % t_ons == 0:
+ Q, R = torch.linalg.qr(Q)
+ log_R_sum = log_R_sum + R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log()
+ n_lyap_steps += 1
+
+ # --- H step: v_comb = v_H + v_L, JVP through H_level ---
+ v_H_all = Q[:, :state_dim, :]
+ v_L_all = Q[:, state_dim:, :]
+ v_comb = v_H_all + v_L_all
+ new_v_H_cols = []
+ f_H = lambda z: inner.H_level(z, zL, **seq_info)
+ for i in range(k_lyap):
+ v_i = v_comb[:, :, i].reshape(B, seq_full, hidden).to(inner.forward_dtype)
+ zH_new, Dv = jvp_through(f_H, zH, v_i)
+ new_v_H_cols.append(Dv.reshape(B, state_dim).to(torch.float32))
+ new_v_H = torch.stack(new_v_H_cols, dim=-1)
+ Q = torch.cat([new_v_H, v_L_all], dim=1)
+ zH = zH_new
+ step_counter += 1
+ if step_counter % t_ons == 0:
+ Q, R = torch.linalg.qr(Q)
+ log_R_sum = log_R_sum + R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log()
+ n_lyap_steps += 1
+
+ z_H, z_L = zH, zL
+
+ drift_zH_per_step.append((z_H - z_H_prev).float().flatten(1).norm(dim=1).cpu())
+ drift_zL_per_step.append((z_L - z_L_prev).float().flatten(1).norm(dim=1).cpu())
+
+ with torch.no_grad():
+ q_logits = inner.q_head(z_H[:, 0]).float()
+ q_halt, q_continue = q_logits[..., 0], q_logits[..., 1]
+ q_halt_hist.append(q_halt.cpu()); q_continue_hist.append(q_continue.cpu())
+ new_halt = (q_halt > q_continue) & (halted_at == 0)
+ halted_at[new_halt] = act_step + 1
+ output = inner.lm_head(z_H)[:, inner.puzzle_emb_len:].float()
+ final_logits = output
+
+ lyap_spec = (log_R_sum / max(n_lyap_steps, 1)).cpu().numpy() # (B, k)
+
+ with torch.no_grad():
+ preds = final_logits.argmax(dim=-1)
+ labels = batch["labels"].to(device)
+ mask = labels > 0
+ exact = ((preds == labels) | ~mask).all(dim=-1).cpu().float()
+ token_acc = ((preds == labels) & mask).sum(-1).float() / mask.sum(-1).float().clamp_min(1)
+ token_acc = token_acc.cpu()
+
+ return {
+ "drift_zH": torch.stack(drift_zH_per_step, dim=1).numpy(),
+ "drift_zL": torch.stack(drift_zL_per_step, dim=1).numpy(),
+ "halted_at": halted_at.cpu().numpy(),
+ "q_halt": torch.stack(q_halt_hist, dim=1).numpy(),
+ "q_continue": torch.stack(q_continue_hist, dim=1).numpy(),
+ "lyap_spec": lyap_spec,
+ "exact_correct": exact.numpy(),
+ "token_acc": token_acc.numpy(),
+ }
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--ckpt-root", required=True)
+ ap.add_argument("--ckpt-name", default="step_26040")
+ ap.add_argument("--n-samples", type=int, default=1024)
+ ap.add_argument("--shard-id", type=int, default=0)
+ ap.add_argument("--num-shards", type=int, default=1)
+ ap.add_argument("--batch-size", type=int, default=32)
+ ap.add_argument("--k-lyap", type=int, default=8)
+ ap.add_argument("--t-ons", type=int, default=1)
+ ap.add_argument("--seed", type=int, default=0)
+ ap.add_argument("--out", default="diag_joint.npz")
+ args = ap.parse_args()
+
+ device = "cuda"
+ model, cfg, train_meta = load_model(Path(args.ckpt_root), args.ckpt_name, device)
+ print(f"loaded {args.ckpt_name}: hidden={model.inner.config.hidden_size}, "
+ f"seq_full={train_meta['seq_len'] + model.inner.puzzle_emb_len}, "
+ f"halt_max_steps={model.inner.config.halt_max_steps}, "
+ f"H={model.inner.config.H_cycles} L={model.inner.config.L_cycles}")
+
+ test_samples = load_test_samples(Path(cfg["data_path"]), args.n_samples,
+ args.shard_id, args.num_shards, args.seed)
+ n_this = len(test_samples["inputs"])
+ print(f"shard {args.shard_id}/{args.num_shards}: {n_this} samples")
+
+ results = {k: [] for k in ["drift_zH","drift_zL","halted_at","q_halt","q_continue",
+ "lyap_spec","exact_correct","token_acc","idx"]}
+ t0 = time.time()
+ for s in range(0, n_this, args.batch_size):
+ e = min(s + args.batch_size, n_this)
+ batch = {k: test_samples[k][s:e].to(device) for k in ["inputs","labels","puzzle_identifiers"]}
+ out = run_diagnose_batch(model, batch, device, args.k_lyap, args.t_ons, args.seed + s)
+ for k, v in out.items():
+ if v is not None: results[k].append(v)
+ results["idx"].append(test_samples["idx"][s:e])
+ ls = out["lyap_spec"]
+ print(f" [{e}/{n_this}] dt={time.time()-t0:.1f}s exact={out['exact_correct'].mean():.3f} "
+ f"λ_1={ls[:,0].mean():.4f} λ_{args.k_lyap}={ls[:,-1].mean():.4f}", flush=True)
+
+ saved = {}
+ for k, v in results.items():
+ if not v: continue
+ try: saved[k] = np.concatenate(v, 0)
+ except ValueError: saved[k] = np.stack(v, 0)
+ np.savez_compressed(args.out, **saved)
+
+ ls = saved["lyap_spec"]
+ succ = saved["exact_correct"] > 0.5
+ print(f"\nN={len(saved['exact_correct'])} acc={succ.mean():.4f}")
+ print(f"{'i':>3} {'mean':>10} {'succ':>10} {'fail':>10} {'Δ(f-s)':>10}")
+ for i in range(ls.shape[1]):
+ li = ls[:, i]
+ print(f"{i+1:>3} {li.mean():+10.4f} {li[succ].mean():+10.4f} {li[~succ].mean():+10.4f} "
+ f"{li[~succ].mean()-li[succ].mean():+10.4f}")
+ print(f"\nsaved → {args.out}")
+
+
+if __name__ == "__main__":
+ main()