diff options
Diffstat (limited to 'diagnose_trm_joint_clv.py')
| -rw-r--r-- | diagnose_trm_joint_clv.py | 304 |
1 files changed, 304 insertions, 0 deletions
# Reconstructed from the diff view of diagnose_trm_joint_clv.py (new file, 304 added lines).
"""TRM Sudoku joint Lyapunov diagnostic — TRM version of diagnose_hrm_joint.py.

Key differences from HRM:
- TRM has ONE shared L_level (H_layers config is "ignored")
- z_L update: z_L = L_level(z_L, z_H + input_embeddings)
- z_H update: z_H = L_level(z_H, z_L) ← same L_level!
- H_cycles=3, L_cycles=6 (vs HRM 2,2)

Joint tangent block structure:
- L step: v_L_new = J · (v_L + v_H), v_H_new = v_H, J at (z_L + z_H + ie)
- H step: v_H_new = J' · (v_H + v_L), v_L_new = v_L, J' at (z_H + z_L)
J and J' share weights but evaluated at different points.
"""
from __future__ import annotations

import argparse
import json
import math
import os
import sys
import time
from pathlib import Path

import numpy as np
import torch

# Location of the TRM checkout. The original hard-coded path stays the default;
# set the TRM_DIR environment variable to point somewhere else.
TRM_DIR = Path(os.environ.get("TRM_DIR", "/home/yurenh2/rrm/trm"))
sys.path.insert(0, str(TRM_DIR))


def _strip_wrapper_prefixes(key):
    """Map a checkpoint key onto the bare model's parameter name.

    ``_orig_mod.`` is removed wherever it occurs (presumably the wrapper
    attribute that ``torch.compile`` inserts — confirm against the training
    script). ``model.`` is removed only as a *leading* prefix, so a parameter
    whose own name merely contains ``model.`` is not corrupted.
    """
    key = key.replace("_orig_mod.", "")
    while key.startswith("model."):
        key = key[len("model."):]
    return key


def load_model(ckpt_root: Path, ckpt_name: str, device: str):
    """Build the TRM model described by a checkpoint directory and load its weights.

    Args:
        ckpt_root: Directory containing ``all_config.yaml`` and the checkpoint file.
        ckpt_name: File name of the checkpoint inside ``ckpt_root``.
        device: Device the model is moved to after loading.

    Returns:
        ``(model, cfg, train_meta)``: the model in eval mode, the parsed
        ``all_config.yaml`` and the parsed ``train/dataset.json`` of the first
        data path listed in the config.
    """
    # Imported lazily so the data/JVP helpers in this module stay importable
    # without PyYAML or the TRM checkout on sys.path.
    import yaml
    from models.recursive_reasoning.trm import TinyRecursiveReasoningModel_ACTV1

    cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text())
    arch_cfg = dict(cfg["arch"])
    train_meta = json.loads((Path(cfg["data_paths"][0]) / "train" / "dataset.json").read_text())
    arch_cfg.update(batch_size=cfg["global_batch_size"], seq_len=train_meta["seq_len"],
                    vocab_size=train_meta["vocab_size"],
                    num_puzzle_identifiers=train_meta["num_puzzle_identifiers"])
    model = TinyRecursiveReasoningModel_ACTV1(arch_cfg)
    sd = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True)
    stripped = {_strip_wrapper_prefixes(k): v for k, v in sd.items()}
    # strict=False: mismatches are reported below rather than raised.
    missing, unexpected = model.load_state_dict(stripped, strict=False)
    print(f"[load] missing={len(missing)} unexpected={len(unexpected)}")
    if missing[:3]:
        print(f" sample missing: {missing[:3]}")
    if unexpected[:3]:
        print(f" sample unexpected: {unexpected[:3]}")
    model.to(device).eval()
    return model, cfg, train_meta


def load_test_samples(data_path, n_total, shard_id, num_shards, seed):
    """Draw a seeded random subset of the test split and return one shard of it.

    All shards called with the same ``seed`` and ``n_total`` see the same
    global draw and take disjoint contiguous slices of it. Shards are sized
    by ceil-division, so trailing shards may be shorter or empty.

    Args:
        data_path: Dataset root containing ``test/all__*.npy``.
        n_total: Number of samples drawn across all shards. Clamped to the
            size of the test split, because sampling is without replacement.
        shard_id: Index of the shard to return.
        num_shards: Total number of shards.
        seed: Seed for the sampling RNG.

    Returns:
        Dict with int32 tensors ``inputs``, ``labels``, ``puzzle_identifiers``
        and the NumPy array ``idx`` of selected row indices.
    """
    rng = np.random.default_rng(seed)
    test_dir = Path(data_path) / "test"
    inputs = np.load(test_dir / "all__inputs.npy")
    labels = np.load(test_dir / "all__labels.npy")
    pid = np.load(test_dir / "all__puzzle_identifiers.npy")
    if n_total > len(inputs):
        # rng.choice(replace=False) raises when asked for more than the population.
        print(f"[data] requested {n_total} samples but only {len(inputs)} available; using all")
        n_total = len(inputs)
    all_idx = rng.choice(len(inputs), size=n_total, replace=False)
    shard_size = (n_total + num_shards - 1) // num_shards
    s, e = shard_id * shard_size, min((shard_id + 1) * shard_size, n_total)
    idx = all_idx[s:e]
    return {
        "inputs": torch.from_numpy(inputs[idx].astype(np.int32)),
        "labels": torch.from_numpy(labels[idx].astype(np.int32)),
        "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)),
        "idx": idx,
    }


def jvp_through(f, x, v):
    """Return ``(f(x), J_f(x) @ v)`` using ``torch.autograd.functional.jvp``."""
    return torch.autograd.functional.jvp(f, x, v=v, create_graph=False, strict=False)


def _push_tangents(inner, seq_info, state, injection, tangents, dims):
    """Advance ``state`` one ``L_level`` step and push tangent columns through it.

    Args:
        inner: Model core exposing ``L_level`` and ``forward_dtype``.
        seq_info: Extra keyword arguments forwarded to ``L_level``.
        state: Current latent state (the variable being updated).
        injection: Second argument of ``L_level``; held fixed for this step.
        tangents: ``(B, D, k)`` float32 tangent columns, ``D = seq_full * hidden``.
        dims: ``(B, seq_full, hidden)``.

    Returns:
        ``(new_state, new_tangents)`` where ``new_tangents`` is ``(B, D, k)``
        float32 and column ``i`` is the JVP of the step applied to column ``i``.
    """
    B, seq_full, hidden = dims

    def step(z):
        return inner.L_level(z, injection, **seq_info)

    new_state = None
    cols = []
    # One JVP per tangent column; each call also re-evaluates the forward step.
    for i in range(tangents.shape[-1]):
        v_i = tangents[:, :, i].reshape(B, seq_full, hidden).to(inner.forward_dtype)
        new_state, Dv = jvp_through(step, state, v_i)
        cols.append(Dv.reshape(B, seq_full * hidden).to(torch.float32))
    return new_state, torch.stack(cols, dim=-1)


def run_diagnose_batch(model, batch, device, k_lyap, t_ons, seed):
    """Run the joint (z_H, z_L) Lyapunov / CLV diagnostic on one batch.

    The model is unrolled for ``halt_max_steps`` ACT steps (halting is only
    recorded, never acted on). A ``k_lyap``-column tangent basis is propagated
    through every L and H update and re-orthonormalised by QR every ``t_ons``
    updates. The stored R factors then drive a Ginelli-style backward pass to
    obtain covariant Lyapunov vectors at the trajectory midpoint.

    Args:
        model: Wrapper whose ``.inner`` exposes the TRM core.
        batch: Dict with ``inputs``, ``labels`` and ``puzzle_identifiers`` tensors.
        device: Device used for the tangent basis and RNG.
        k_lyap: Number of tangent directions / exponents (>= 1).
        t_ons: Number of updates between QR re-orthonormalisations (>= 1).
        seed: Seed for the initial tangent basis; ``seed + 9973`` seeds the CLV pass.

    Returns:
        Dict of NumPy arrays: per-step drifts, halting signals, the Lyapunov
        spectrum, accuracy, and direction-based features of both the final QR
        basis (``clv_hfrac``, ``clv_pr``, ``clv_pr_ans``, ``clv_readout``) and
        the covariant vectors (``clv_maxcos``, ``clv_meancos``,
        ``clv_minangle``, ``clv_true_*``).

    Raises:
        ValueError: If ``k_lyap`` or ``t_ons`` is smaller than 1.
    """
    if k_lyap < 1 or t_ons < 1:
        raise ValueError(f"k_lyap and t_ons must be >= 1 (got k_lyap={k_lyap}, t_ons={t_ons})")

    inner = model.inner
    cfg = inner.config
    B = batch["inputs"].shape[0]
    seq_full = cfg.seq_len + inner.puzzle_emb_len
    hidden = cfg.hidden_size
    D = seq_full * hidden
    dims = (B, seq_full, hidden)

    z_H = inner.H_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype)
    z_L = inner.L_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype)
    seq_info = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None)
    input_embeddings = inner._input_embeddings(batch["inputs"].to(device),
                                               batch["puzzle_identifiers"].to(device))

    # Joint tangent basis: rows [0, D) are the z_H part, rows [D, 2D) the z_L part.
    g = torch.Generator(device=device).manual_seed(seed)
    Q0 = torch.randn(B, 2 * D, k_lyap, device=device, dtype=torch.float32, generator=g)
    Q, _ = torch.linalg.qr(Q0)
    log_R_sum = torch.zeros(B, k_lyap, device=device, dtype=torch.float32)
    n_lyap_steps = 0
    step_counter = 0

    drift_zH_per_step, drift_zL_per_step = [], []
    halted_at = torch.zeros(B, dtype=torch.long, device=device)
    q_halt_hist, q_continue_hist = [], []

    # Ginelli CLV: store R at every QR, capture Q at the trajectory midpoint (analysis time).
    R_list = []
    Q_mid = None
    target_qr = (cfg.halt_max_steps * cfg.H_cycles * (cfg.L_cycles + 1) // t_ons) // 2

    for act_step in range(cfg.halt_max_steps):
        z_H_prev = z_H.detach().clone()
        z_L_prev = z_L.detach().clone()

        with torch.enable_grad():
            zH, zL = z_H.detach(), z_L.detach()
            for _h in range(cfg.H_cycles):
                # Each H cycle is L_cycles updates of z_L followed by one update
                # of z_H, all through the SAME L_level.
                for sub in range(cfg.L_cycles + 1):
                    v_H = Q[:, :D, :]
                    v_L = Q[:, D:, :]
                    v_sum = v_H + v_L
                    if sub < cfg.L_cycles:
                        # L step: only the z_L tangent changes.
                        zL, new_v_L = _push_tangents(
                            inner, seq_info, zL, zH + input_embeddings, v_sum, dims)
                        Q = torch.cat([v_H, new_v_L], dim=1)
                    else:
                        # H step: only the z_H tangent changes.
                        zH, new_v_H = _push_tangents(inner, seq_info, zH, zL, v_sum, dims)
                        Q = torch.cat([new_v_H, v_L], dim=1)
                    step_counter += 1
                    if step_counter % t_ons == 0:
                        Q, R = torch.linalg.qr(Q)
                        log_R_sum = log_R_sum + R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log()
                        n_lyap_steps += 1
                        R_list.append(R.detach())
                        if n_lyap_steps == target_qr:
                            Q_mid = Q.detach().clone()

        z_H, z_L = zH, zL

        drift_zH_per_step.append((z_H - z_H_prev).float().flatten(1).norm(dim=1).cpu())
        drift_zL_per_step.append((z_L - z_L_prev).float().flatten(1).norm(dim=1).cpu())

        with torch.no_grad():
            q_logits = inner.q_head(z_H[:, 0]).float()
            q_halt, q_continue = q_logits[..., 0], q_logits[..., 1]
            q_halt_hist.append(q_halt.cpu())
            q_continue_hist.append(q_continue.cpu())
            # Record the first step at which halting would trigger; the unroll continues.
            new_halt = (q_halt > q_continue) & (halted_at == 0)
            halted_at[new_halt] = act_step + 1
            output = inner.lm_head(z_H)[:, inner.puzzle_emb_len:].float()
            final_logits = output

    # Mean log|R_ii| per re-orthonormalisation, i.e. per t_ons updates (not per single update).
    lyap_spec = (log_R_sum / max(n_lyap_steps, 1)).cpu().numpy()

    # Ginelli CLV backward pass: iterate C = normalize(R^{-1} C) from the end back to the
    # midpoint analysis time, then CLVs = Q_mid @ C. Geometry NOT captured by exponents:
    # pairwise CLV angles (Oseledets-splitting tangency).
    with torch.no_grad():
        if Q_mid is None:
            # NOTE(review): reached only when fewer than two QR steps happen
            # (target_qr == 0); the final Q is then paired with R_list[0:],
            # which is not a consistent analysis time. Original fallback kept.
            Q_mid = Q  # fallback (shouldn't happen with no early halt)
        g2 = torch.Generator(device=device).manual_seed(seed + 9973)
        C = torch.triu(torch.randn(B, k_lyap, k_lyap, device=device, dtype=torch.float32, generator=g2))
        C = C / C.norm(dim=1, keepdim=True).clamp_min(1e-20)
        n_used = 0
        for R in reversed(R_list[target_qr:]):
            Rf = R.float()
            # guard against tiny diagonal (rank-deficient tangent) before triangular solve
            diag = Rf.diagonal(dim1=-2, dim2=-1)
            Rf = Rf + torch.diag_embed(
                torch.where(diag.abs() < 1e-8, torch.full_like(diag, 1e-8), torch.zeros_like(diag)))
            C = torch.linalg.solve_triangular(Rf, C, upper=True)
            C = C / C.norm(dim=1, keepdim=True).clamp_min(1e-20)
            n_used += 1
        V = torch.matmul(Q_mid.float(), C)  # (B, 2D, k) covariant LVs
        V = V / V.norm(dim=1, keepdim=True).clamp_min(1e-20)  # unit columns (non-orthogonal)
        gram = torch.matmul(V.transpose(1, 2), V).abs()  # (B, k, k) |cos| between CLVs
        eye = torch.eye(k_lyap, device=device).unsqueeze(0)
        off = gram * (1 - eye)
        clv_maxcos = off.amax(dim=(1, 2)).cpu().numpy()  # tangency: max |cos| (1=degenerate)
        # k*(k-1) off-diagonal pairs; max(..., 1) avoids 0/0 when k_lyap == 1.
        clv_meancos = (off.sum(dim=(1, 2)) / max(k_lyap * (k_lyap - 1), 1)).cpu().numpy()
        clv_minangle = torch.acos(off.amax(dim=(1, 2)).clamp(max=1 - 1e-7)).cpu().numpy()
        # leading TRUE CLV localization/readout (vs the leading-GS-vector version below)
        Dh0 = seq_full * hidden
        pe0 = inner.puzzle_emb_len
        vH = V[:, :Dh0, 0].reshape(B, seq_full, hidden)
        eH = (vH ** 2).sum(-1)
        eL = (V[:, Dh0:, 0].reshape(B, seq_full, hidden) ** 2).sum(-1)
        e = eH + eL
        clv_true_hfrac = (eH.sum(-1) / e.sum(-1).clamp_min(1e-20)).cpu().numpy()
        clv_true_pr_ans = (e[:, pe0:].sum(-1) ** 2 / (e[:, pe0:] ** 2).sum(-1).clamp_min(1e-20)
                           / max(seq_full - pe0, 1)).cpu().numpy()
        ro = torch.matmul(vH[:, pe0:], inner.lm_head.weight.float().t())
        clv_true_readout = (ro.flatten(1).norm(dim=1)
                            / vH[:, pe0:].flatten(1).norm(dim=1).clamp_min(1e-20)).cpu().numpy()

    with torch.no_grad():
        preds = final_logits.argmax(dim=-1)
        labels = batch["labels"].to(device)
        mask = labels > 0
        exact = ((preds == labels) | ~mask).all(dim=-1).cpu().float()
        token_acc = ((preds == labels) & mask).sum(-1).float() / mask.sum(-1).float().clamp_min(1)
        token_acc = token_acc.cpu()

    # Geometry of the leading columns of the FINAL Q. These are Gram-Schmidt
    # (orthonormalised) vectors; only the first column spans the same direction
    # as a covariant vector — the covariant ones are the clv_true_* features above.
    # Features are DIRECTION-based (not scalar growth rates): H/L energy split, token
    # participation ratio (localization), and alignment with the decode/readout direction.
    with torch.no_grad():
        Dh = seq_full * hidden
        n_top = min(3, k_lyap)
        pe = inner.puzzle_emb_len
        W = inner.lm_head.weight.float()  # (vocab, hidden), bias-free
        hfrac_c, pr_c, prans_c, ro_c = [], [], [], []
        for j in range(n_top):
            # Unit norm only if the last update ended on a QR step; every
            # feature below is a ratio, so the scale cancels either way.
            col = Q[:, :, j].float()  # (B, 2Dh)
            qH = col[:, :Dh].reshape(B, seq_full, hidden)
            qL = col[:, Dh:].reshape(B, seq_full, hidden)
            eH = (qH ** 2).sum(-1)
            eL = (qL ** 2).sum(-1)  # (B, seq_full) per-token energy
            e = eH + eL
            hfrac_c.append((eH.sum(-1) / e.sum(-1).clamp_min(1e-20)).cpu())
            pr_c.append((e.sum(-1) ** 2 / (e ** 2).sum(-1).clamp_min(1e-20) / seq_full).cpu())
            ea = e[:, pe:]
            prans_c.append((ea.sum(-1) ** 2 / (ea ** 2).sum(-1).clamp_min(1e-20)
                            / max(seq_full - pe, 1)).cpu())
            ro = torch.matmul(qH[:, pe:], W.t())  # (B, ans, vocab) readout projection
            ro_c.append((ro.flatten(1).norm(dim=1)
                         / qH[:, pe:].flatten(1).norm(dim=1).clamp_min(1e-20)).cpu())

    return {
        "drift_zH": torch.stack(drift_zH_per_step, dim=1).numpy(),
        "drift_zL": torch.stack(drift_zL_per_step, dim=1).numpy(),
        "halted_at": halted_at.cpu().numpy(),
        "q_halt": torch.stack(q_halt_hist, dim=1).numpy(),
        "q_continue": torch.stack(q_continue_hist, dim=1).numpy(),
        "lyap_spec": lyap_spec,
        "exact_correct": exact.numpy(),
        "token_acc": token_acc.numpy(),
        "clv_hfrac": torch.stack(hfrac_c, 1).numpy(),
        "clv_pr": torch.stack(pr_c, 1).numpy(),
        "clv_pr_ans": torch.stack(prans_c, 1).numpy(),
        "clv_readout": torch.stack(ro_c, 1).numpy(),
        "clv_maxcos": clv_maxcos,
        "clv_meancos": clv_meancos,
        "clv_minangle": clv_minangle,
        "clv_true_hfrac": clv_true_hfrac,
        "clv_true_pr_ans": clv_true_pr_ans,
        "clv_true_readout": clv_true_readout,
    }


def _mean_or_nan(values):
    """Mean of a NumPy array, or NaN (without a RuntimeWarning) when it is empty."""
    return float(values.mean()) if values.size else float("nan")


def main():
    """CLI entry point: load a checkpoint, diagnose one shard, save an ``.npz``.

    Prints per-batch progress and, at the end, the Lyapunov spectrum averaged
    over all samples and split by exact-match success / failure.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt-root", required=True)
    ap.add_argument("--ckpt-name", default="step_13020")
    ap.add_argument("--n-samples", type=int, default=512)
    ap.add_argument("--shard-id", type=int, default=0)
    ap.add_argument("--num-shards", type=int, default=1)
    ap.add_argument("--batch-size", type=int, default=16)
    ap.add_argument("--k-lyap", type=int, default=8)
    ap.add_argument("--t-ons", type=int, default=1)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--out", default="diag_trm.npz")
    args = ap.parse_args()

    device = "cuda"
    model, cfg, train_meta = load_model(Path(args.ckpt_root), args.ckpt_name, device)
    print(f"loaded {args.ckpt_name}: hidden={model.inner.config.hidden_size}, "
          f"seq_full={train_meta['seq_len'] + model.inner.puzzle_emb_len}, "
          f"halt_max_steps={model.inner.config.halt_max_steps}, "
          f"H={model.inner.config.H_cycles} L={model.inner.config.L_cycles}")

    test = load_test_samples(Path(cfg["data_paths"][0]), args.n_samples, args.shard_id,
                             args.num_shards, args.seed)
    n = len(test["inputs"])
    print(f"shard {args.shard_id}/{args.num_shards}: {n} samples")

    res = {k: [] for k in ["drift_zH", "drift_zL", "halted_at", "q_halt", "q_continue",
                           "lyap_spec", "exact_correct", "token_acc", "idx"]}
    t0 = time.time()
    for s in range(0, n, args.batch_size):
        e = min(s + args.batch_size, n)
        batch = {k: test[k][s:e].to(device) for k in ["inputs", "labels", "puzzle_identifiers"]}
        # Offset the seed by the batch start so each batch gets its own tangent basis.
        out = run_diagnose_batch(model, batch, device, args.k_lyap, args.t_ons, args.seed + s)
        for k, v in out.items():
            res.setdefault(k, []).append(v)
        res.setdefault("idx", []).append(test["idx"][s:e])
        ls = out["lyap_spec"]
        print(f" [{e}/{n}] dt={time.time()-t0:.1f}s exact={out['exact_correct'].mean():.3f} "
              f"λ_1={ls[:,0].mean():+.4f} λ_{args.k_lyap}={ls[:,-1].mean():+.4f}", flush=True)

    saved = {}
    for k, v in res.items():
        if not v:
            continue
        try:
            saved[k] = np.concatenate(v, 0)
        except ValueError:
            # e.g. zero-dimensional entries, which concatenate rejects
            saved[k] = np.stack(v, 0)
    np.savez_compressed(args.out, **saved)

    if "exact_correct" not in saved:
        # Ceil-division sharding can leave trailing shards empty; nothing to summarise.
        print("\nN=0 (empty shard) — no summary")
        return

    succ = saved["exact_correct"] > 0.5
    print(f"\nN={len(succ)} acc={succ.mean():.4f}")
    print(f"{'i':>3} {'all':>10} {'succ':>10} {'fail':>10} {'Δ':>9}")
    for i in range(saved["lyap_spec"].shape[1]):
        li = saved["lyap_spec"][:, i]
        m_succ = _mean_or_nan(li[succ])
        m_fail = _mean_or_nan(li[~succ])
        print(f"{i+1:>3} {li.mean():+10.4f} {m_succ:+10.4f} {m_fail:+10.4f} {m_fail-m_succ:+9.4f}")


if __name__ == "__main__":
    main()
