"""Step 3: Continue HRM training with joint-Lyapunov reverse-flossing regularizer. L_total = L_HRM_ACT(supervised) + alpha * L_RF L_RF = mean over batch of max(0, lambda_joint_1 - lambda_star) ** 2 FORWARD direction (what we measure): - Joint tangent Q in R^{B x 2D x k} evolved via block-matrix Jacobian along z trajectory (using JVP with create_graph=True so RF loss is diff'ble in theta). - Lyapunov spectrum = (1/T) * sum_t log|R_ii(t)| from QR re-orthogonalization. BACKWARD direction (what flows to theta): - Standard autograd of L_total through the entire forward graph (including the JVP chain), as in Engelken's flossing. The QR decomposition's backward is handled by PyTorch autograd; no manual pullback needed. Loaded from intermediate checkpoint (e.g. step_18228, before the success/failure contraction gap fully forms in the original training). """ from __future__ import annotations import sys, os, yaml, json, math, time, argparse from pathlib import Path import numpy as np import torch import torch.nn.functional as F HRM_DIR = Path("/home/yurenh2/rrm/hrm") sys.path.insert(0, str(HRM_DIR)) from models.hrm.hrm_act_v1 import HierarchicalReasoningModel_ACTV1 from models.losses import ACTLossHead from adam_atan2 import AdamATan2 def load_model(ckpt_root: Path, ckpt_name: str, device: str): cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text()) arch_cfg = dict(cfg["arch"]) train_meta = json.loads((Path(cfg["data_path"]) / "train" / "dataset.json").read_text()) arch_cfg.update(batch_size=cfg["global_batch_size"], seq_len=train_meta["seq_len"], vocab_size=train_meta["vocab_size"], num_puzzle_identifiers=train_meta["num_puzzle_identifiers"], causal=False) base = HierarchicalReasoningModel_ACTV1(arch_cfg) head = ACTLossHead(base, loss_type=arch_cfg["loss"]["loss_type"]) sd = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True) stripped = {k.replace("_orig_mod.", ""): v for k, v in sd.items()} missing, unexpected = head.load_state_dict(stripped, strict=False) print(f"[load {ckpt_name}] missing={len(missing)} unexpected={len(unexpected)}") head.to(device) return head, base, cfg, train_meta def jvp_train(f, x, v): """JVP that participates in the autograd graph (create_graph=True).""" return torch.autograd.functional.jvp(f, x, v=v, create_graph=True, strict=False) def compute_joint_lyap_spec(base, batch, k_lyap, lyap_act_steps, device, seed): """Returns per-sample top-k Lyapunov spectrum of joint (z_H, z_L) dynamics, differentiable wrt the model parameters. Shape (B, k).""" inner = base.inner cfg = inner.config B = batch["inputs"].shape[0] seq_full = cfg.seq_len + inner.puzzle_emb_len hidden = cfg.hidden_size D = seq_full * hidden z_H = inner.H_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) z_L = inner.L_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) seq_info = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None) input_embeddings = inner._input_embeddings(batch["inputs"], batch["puzzle_identifiers"]) g = torch.Generator(device=device).manual_seed(seed) Q0 = torch.randn(B, 2*D, k_lyap, device=device, dtype=torch.float32, generator=g) Q, _ = torch.linalg.qr(Q0) log_R_sum = torch.zeros(B, k_lyap, device=device, dtype=torch.float32) n_steps = 0 n_act = min(lyap_act_steps, cfg.halt_max_steps) for _act in range(n_act): for _h in range(cfg.H_cycles): for _l in range(cfg.L_cycles): v_H_j = Q[:, :D, :] v_L_j = Q[:, D:, :] v_comb = v_H_j + v_L_j new_v_L_cols = [] f_L = lambda z: inner.L_level(z, z_H + input_embeddings, **seq_info) for i in range(k_lyap): v_i = v_comb[:, :, i].reshape(B, seq_full, hidden).to(inner.forward_dtype) z_L_new, Dv = jvp_train(f_L, z_L, v_i) new_v_L_cols.append(Dv.reshape(B, D).to(torch.float32)) new_v_L = torch.stack(new_v_L_cols, dim=-1) Q = torch.cat([v_H_j, new_v_L], dim=1) z_L = z_L_new Q, R = torch.linalg.qr(Q) log_R_sum = log_R_sum + R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log() n_steps += 1 v_H_j = Q[:, :D, :] v_L_j = Q[:, D:, :] v_comb = v_H_j + v_L_j new_v_H_cols = [] f_H = lambda z: inner.H_level(z, z_L, **seq_info) for i in range(k_lyap): v_i = v_comb[:, :, i].reshape(B, seq_full, hidden).to(inner.forward_dtype) z_H_new, Dv = jvp_train(f_H, z_H, v_i) new_v_H_cols.append(Dv.reshape(B, D).to(torch.float32)) new_v_H = torch.stack(new_v_H_cols, dim=-1) Q = torch.cat([new_v_H, v_L_j], dim=1) z_H = z_H_new Q, R = torch.linalg.qr(Q) log_R_sum = log_R_sum + R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log() n_steps += 1 lyap_spec = log_R_sum / max(n_steps, 1) # (B, k) return lyap_spec # full spectrum (B, k) def load_train_batches(data_path: Path, batch_size: int, n_iters: int, seed: int = 0): rng = np.random.default_rng(seed) inputs = np.load(data_path / "train" / "all__inputs.npy") labels = np.load(data_path / "train" / "all__labels.npy") pid = np.load(data_path / "train" / "all__puzzle_identifiers.npy") N = len(inputs) for _ in range(n_iters): idx = rng.choice(N, size=batch_size, replace=False) yield { "inputs": torch.from_numpy(inputs[idx].astype(np.int32)), "labels": torch.from_numpy(labels[idx].astype(np.int32)), "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)), } def evaluate(head, base, data_path, n_samples, batch_size, device, seed=42): rng = np.random.default_rng(seed) inputs = np.load(data_path / "test" / "all__inputs.npy") labels = np.load(data_path / "test" / "all__labels.npy") pid = np.load(data_path / "test" / "all__puzzle_identifiers.npy") idx_all = rng.choice(len(inputs), size=n_samples, replace=False) head.eval() correct = 0; token_correct = 0; token_total = 0 for s in range(0, n_samples, batch_size): e = min(s + batch_size, n_samples) idx = idx_all[s:e] batch = { "inputs": torch.from_numpy(inputs[idx].astype(np.int32)).to(device), "labels": torch.from_numpy(labels[idx].astype(np.int32)).to(device), "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)).to(device), } with torch.no_grad(): with torch.device(device): carry = base.initial_carry(batch) for _ in range(base.config.halt_max_steps): carry, outputs = base(carry=carry, batch=batch) preds = outputs["logits"].argmax(dim=-1) mask = batch["labels"] > 0 exact = ((preds == batch["labels"]) | ~mask).all(dim=-1).float() correct += exact.sum().item() token_correct += ((preds == batch["labels"]) & mask).sum().item() token_total += mask.sum().item() return correct / n_samples, token_correct / max(token_total, 1) def main(): ap = argparse.ArgumentParser() ap.add_argument("--ckpt-root", required=True) ap.add_argument("--ckpt-name", default="step_18228", help="start from this checkpoint and continue training") ap.add_argument("--n-steps", type=int, default=2000) ap.add_argument("--batch-size", type=int, default=8) ap.add_argument("--lr", type=float, default=1e-5) ap.add_argument("--alpha-rf", type=float, default=0.0, help="RF weight; 0 = baseline") ap.add_argument("--lambda-star", type=float, default=-0.05, help="joint Lyapunov target. λ_joint_1 should be < λ_star for stable joint dynamics.") ap.add_argument("--rf-mode", choices=["fixed","volume_cf","gelu","engelken_l2"], default="fixed", help="fixed: max(0,λ-λ*)² hinge (one-sided). " "volume_cf: max(0, mean_i λ_i - λ*)² over the measured top-k spectrum. " "gelu: GeLU(λ) attractor at -0.75 (deprecated, kills success). " "engelken_l2: (1/k) Σ λ_i² across full top-k (Engelken 2023, two-sided to 0).") ap.add_argument("--k-lyap", type=int, default=2) ap.add_argument("--lyap-act-steps", type=int, default=4) ap.add_argument("--seed", type=int, default=42) ap.add_argument("--eval-every", type=int, default=200) ap.add_argument("--eval-n", type=int, default=512) ap.add_argument("--eval-batch-size", type=int, default=32) ap.add_argument("--out", default="step3_log.json") args = ap.parse_args() device = "cuda" head, base, cfg, train_meta = load_model(Path(args.ckpt_root), args.ckpt_name, device) optim = AdamATan2(head.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=cfg["weight_decay"]) print(f"\n=== Initial eval (loaded {args.ckpt_name}) ===") acc0, tacc0 = evaluate(head, base, Path(cfg["data_path"]), args.eval_n, args.eval_batch_size, device) print(f" initial: exact_acc = {acc0:.4f} token_acc = {tacc0:.4f}") log = {"args": vars(args), "initial_acc": acc0, "initial_tok_acc": tacc0, "steps": [], "evals": []} log["evals"].append({"step": 0, "acc": acc0, "tok_acc": tacc0}) t0 = time.time() train_iter = load_train_batches(Path(cfg["data_path"]), args.batch_size, args.n_steps, seed=args.seed) for step, batch in enumerate(train_iter): batch = {k: v.to(device) for k, v in batch.items()} head.train() # Sparse puzzle embedding has fixed local_weights buffer (training batch=768); # keep it in eval mode (still uses the weights table, just not the local buffer). base.inner.puzzle_emb.eval() for p in base.inner.puzzle_emb.parameters(): p.requires_grad_(False) # ---- Supervised ACT loss (accumulated over all halt_max_steps ACT steps) ---- with torch.device(device): carry = base.initial_carry(batch) sup_loss_sum = 0.0 n_loss = 0 for _ in range(base.config.halt_max_steps): carry, l, metrics, _, all_finish = head(return_keys=[], carry=carry, batch=batch) sup_loss_sum = sup_loss_sum + l n_loss += 1 if all_finish: break sup_loss = sup_loss_sum / max(n_loss, 1) / args.batch_size # ---- Reverse-flossing penalty (only if alpha > 0) ---- if args.alpha_rf > 0: lyap_spec = compute_joint_lyap_spec( base, batch, k_lyap=args.k_lyap, lyap_act_steps=args.lyap_act_steps, device=device, seed=args.seed + step, ) # (B, k) lyap1 = lyap_spec[:, 0] if args.rf_mode == "engelken_l2": # Engelken 2023: push all top-k λ_i² → 0 (two-sided) rf_loss = (lyap_spec ** 2).mean() excess = lyap1 # log signed λ_1 elif args.rf_mode == "volume_cf": # One-sided cap on local phase-space volume expansion. # Allows λ_1 > 0 when compensated by enough contraction in other modes. lyap_volume = lyap_spec.mean(dim=1) excess = (lyap_volume - args.lambda_star).clamp_min(0.0) rf_loss = (excess ** 2).mean() elif args.rf_mode == "gelu": rf_loss = torch.nn.functional.gelu(lyap1).mean() excess = lyap1 - (-0.751) else: # fixed (hinge on top-1) excess = (lyap1 - args.lambda_star).clamp_min(0.0) rf_loss = (excess ** 2).mean() else: if args.k_lyap > 0: with torch.no_grad(): lyap_spec = compute_joint_lyap_spec( base, batch, k_lyap=args.k_lyap, lyap_act_steps=args.lyap_act_steps, device=device, seed=args.seed + step, ) lyap1 = lyap_spec[:, 0] else: # Fast alpha=0 baseline path: skip diagnostic-only Lyapunov work. lyap1 = torch.zeros(batch["inputs"].shape[0], device=device) lyap_spec = lyap1[:, None] rf_loss = torch.zeros((), device=device) excess = (lyap1 - args.lambda_star).clamp_min(0.0) total_loss = sup_loss + args.alpha_rf * rf_loss optim.zero_grad(set_to_none=True) total_loss.backward() torch.nn.utils.clip_grad_norm_([p for p in head.parameters() if p.requires_grad], 1.0) optim.step() rec = { "step": step, "sup_loss": float(sup_loss.item()), "rf_loss": float(rf_loss.item()), "total_loss": float(total_loss.item()), "lyap1_mean": float(lyap1.detach().mean().item()), "lyap1_max": float(lyap1.detach().max().item()), "lyap_volume_mean": float(lyap_spec.detach().mean(dim=1).mean().item()), "lyap_volume_max": float(lyap_spec.detach().mean(dim=1).max().item()), "frac_above_star": float((excess > 0).float().mean().item()), } log["steps"].append(rec) if step % 10 == 0 or step == args.n_steps - 1: print(f" [{step:>4}/{args.n_steps}] dt={time.time()-t0:.1f}s " f"sup={rec['sup_loss']:.4f} rf={rec['rf_loss']:.4f} " f"λj1_mean={rec['lyap1_mean']:+.4f} max={rec['lyap1_max']:+.4f} " f"frac>λ*={rec['frac_above_star']:.2f}", flush=True) if (step + 1) % args.eval_every == 0: acc, tacc = evaluate(head, base, Path(cfg["data_path"]), args.eval_n, args.eval_batch_size, device) print(f" >> EVAL @ step {step+1}: exact_acc={acc:.4f} (Δ from init: {acc-acc0:+.4f})", flush=True) log["evals"].append({"step": step + 1, "acc": acc, "tok_acc": tacc}) acc_f, tacc_f = evaluate(head, base, Path(cfg["data_path"]), args.eval_n, args.eval_batch_size, device) print(f"\n=== Final eval ===") print(f" initial: {acc0:.4f} final: {acc_f:.4f} (Δ {acc_f-acc0:+.4f})") log["final_acc"] = acc_f log["final_tok_acc"] = tacc_f log["evals"].append({"step": args.n_steps, "acc": acc_f, "tok_acc": tacc_f}) Path(args.out).write_text(json.dumps(log, indent=2)) print(f"log → {args.out}") if __name__ == "__main__": main()