"""Build notebooks/recursive_reasoning_chaos.ipynb via nbformat.

Order: load a trained TRM/HRM from HuggingFace ->
  (1) THE core result: leading finite-time Lyapunov exponent (lambda_1) along the
      inference trajectory separates success from failure (failures more chaotic) ->
  (2) why: transient chaos, failures escape with more compute ->
  (3) basin accessibility.

Running this script writes ``recursive_reasoning_chaos.ipynb`` next to the script
itself. The notebook cells are stored here as string literals; cells that need the
repo id interpolated are f-strings, so literal braces inside them are doubled.
"""
from pathlib import Path

import nbformat as nbf

HF_REPO = "blackhao0426/recursive-reasoning-chaos"  # HF account (GitHub is YurenHao0426)

nb = nbf.v4.new_notebook()
C = []  # notebook cells, in display order


def md(t: str) -> None:
    """Append a markdown cell with source text *t* to the cell list ``C``."""
    C.append(nbf.v4.new_markdown_cell(t))


def code(t: str) -> None:
    """Append a code cell with source text *t* to the cell list ``C``."""
    C.append(nbf.v4.new_code_cell(t))


# ---------------------------------------------------------------- title / overview
md(f"""# Recursive Reasoning Failures are Chaotic

Small recursive reasoners (HRM, TRM) iterate a latent state to solve a puzzle (Sudoku) before emitting an answer. **The core finding:** measured along the recurrent inference trajectory, the **leading finite-time Lyapunov exponent (λ₁) is higher on failed examples than on successful ones** — in the *same* trained network. Failure is locally more chaotic.

This notebook, top to bottom:
1. **Load** a trained model from HuggingFace (`{HF_REPO}`).
2. **The result** — compute λ₁ per example (Benettin / JVP along the trajectory) and show the success-vs-failure separation (histogram + AUC). *This is the headline.*
3. **Why** — run the recurrence far past its training budget: failures are a *transient* — they escape the chaotic set and self-correct given enough compute (TRM), or stay trapped much longer (HRM). Neither settles to a wrong fixed point.
4. **Basin accessibility** — restart from a perturbed initial state: is a failure input-determined or initial-condition-determined?

Companion analysis repo: `github.com/YurenHao0426/recursive-reasoning-dynamics`.""")

# ---------------------------------------------------------------- 0. setup
md("## 0. Setup")

# pyyaml is listed explicitly because the loader cell imports `yaml` directly.
code("""%pip install -q torch einops pydantic huggingface_hub numpy matplotlib tqdm pyyaml
import numpy as np, matplotlib.pyplot as plt, torch
from tqdm.auto import tqdm
print("torch", torch.__version__, "| cuda", torch.cuda.is_available())
if not torch.cuda.is_available():
    print("\\n⚠️ No GPU detected — the JVP/rollout cells will be SLOW on CPU.")
    print(" Colab: Runtime → Change runtime type → Hardware accelerator → GPU (T4), then re-run.")""")

# ---------------------------------------------------------------- 1. load model
md(f"""## 1. Load a trained model from HuggingFace

Downloads model code + checkpoint + a 2000-puzzle test set from `{HF_REPO}`. `MODEL` ∈ {{`trm_sudoku`, `hrm_sudoku`}}. **TRM is MLP-only → runs on a laptop CPU.** To switch models, change `MODEL` and **restart the kernel** (the two ship same-named `models` packages).""")

# f-string cell: {HF_REPO} is interpolated at build time, doubled braces are literal.
code(f"""import sys, yaml, json
from pathlib import Path
from huggingface_hub import snapshot_download
HF_REPO = "{HF_REPO}"
MODEL = "trm_sudoku"  # or "hrm_sudoku"
root = Path(snapshot_download(HF_REPO))
sys.path.insert(0, str(root / ("code_trm" if MODEL.startswith("trm") else "code_hrm")))
cfg = yaml.safe_load((root / MODEL / "all_config.yaml").read_text())
meta = json.loads((root / "data" / "sudoku_meta.json").read_text())
arch = dict(cfg["arch"]); arch.update(batch_size=64, seq_len=meta["seq_len"], vocab_size=meta["vocab_size"],
                                      num_puzzle_identifiers=meta["num_puzzle_identifiers"], causal=False)
if MODEL.startswith("trm"):
    from models.recursive_reasoning.trm import TinyRecursiveReasoningModel_ACTV1 as M
else:
    from models.hrm.hrm_act_v1 import HierarchicalReasoningModel_ACTV1 as M
model = M(arch)
sd = torch.load(root / MODEL / "weights.pt", map_location="cpu", weights_only=True)
load_report = model.load_state_dict({{k.replace("_orig_mod.","").replace("model.",""): v for k,v in sd.items()}}, strict=False)
# strict=False hides key mismatches; surface them so a silently un-loaded model is noticed.
if load_report.missing_keys or load_report.unexpected_keys:
    print(f"note: load_state_dict(strict=False) — missing keys: {{len(load_report.missing_keys)}}, unexpected keys: {{len(load_report.unexpected_keys)}}")
dev = "cuda" if torch.cuda.is_available() else "cpu"; model.to(dev).eval()
inner = model.inner
inp = np.load(root/"data"/"sudoku_test_inputs.npy"); lab = np.load(root/"data"/"sudoku_test_labels.npy")
pid = np.load(root/"data"/"sudoku_test_pid.npy")
print(f"loaded {{MODEL}}: hidden={{inner.config.hidden_size}}, H_cycles={{inner.config.H_cycles}}, L_cycles={{inner.config.L_cycles}}, test puzzles={{len(inp)}}")""")

# ---------------------------------------------------------------- 2. leading FTLE
md("""## 2. The core result — failures are more chaotic (leading FTLE / λ₁)

For each puzzle we run the recurrence for the 16-segment inference budget and propagate one tangent vector through every module update (forward-mode JVP), renormalizing each step and accumulating the log-growth (Benettin's method for the largest exponent). λ₁ = mean log-growth per module-evaluation. Then split by outcome at segment 16. **Failures sit at higher λ₁** — they are locally expanding / chaotic; successes have collapsed toward the solution. `AUC(−λ₁ → success)` near 1 = clean separation.

The JVP uses the fact that the update is `module(a, b) = layers(a + b)`, so a perturbation of the combined input can be fed through one slot.

(GPU: ~1 min. Laptop/CPU with TRM: a few min — lower `n`.)""")

# Plain (non-f) string: the braces below belong to f-strings evaluated inside the notebook.
code("""import torch.autograd.functional as AF
from contextlib import nullcontext
try:  # HRM attention JVP needs the math SDP backend (no FlashAttn double-backward)
    from torch.nn.attention import sdpa_kernel, SDPBackend
    MATHCTX = lambda: sdpa_kernel(SDPBackend.MATH)
except Exception:
    MATHCTX = nullcontext

def auc(score, y):
    p, n = score[y==1], score[y==0]
    if len(p)==0 or len(n)==0: return float('nan')
    a=np.concatenate([p,n]); o=np.argsort(a); r=np.empty(len(a)); r[o]=np.arange(1,len(a)+1)
    return (r[:len(p)].sum()-len(p)*(len(p)+1)/2)/(len(p)*len(n))

def leading_ftle(inp, lab, pid, n=128, n_seg=16, seed=0):
    rng=np.random.default_rng(seed); idx=rng.choice(len(inp), n, replace=False)
    pe=inner.puzzle_emb_len; sf=inner.config.seq_len+pe; hid=inner.config.hidden_size; D=sf*hid; B=n
    is_hrm=hasattr(inner,"H_level") and getattr(inner,"H_level",None) is not None
    Hmod=inner.H_level if is_hrm else inner.L_level  # weight-tied TRM reuses L_level
    X=torch.tensor(inp[idx].astype(np.int32),device=dev); Y=torch.tensor(lab[idx].astype(np.int32),device=dev)
    P=torch.tensor(pid[idx].astype(np.int32),device=dev)
    si=dict(cos_sin=inner.rotary_emb() if hasattr(inner,"rotary_emb") else None)
    g=torch.Generator(device=dev).manual_seed(seed)
    jvp=lambda f,x,v: AF.jvp(f, x, v=v, create_graph=False, strict=False)
    def renorm(vH,vL):
        nrm=torch.sqrt(vH.pow(2).sum(1,keepdim=True)+vL.pow(2).sum(1,keepdim=True)).clamp_min(1e-30)
        return vH/nrm, vL/nrm, nrm.squeeze(1)
    with MATHCTX():
        emb=inner._input_embeddings(X,P); m=Y>0
        zH=inner.H_init.unsqueeze(0).expand(B,sf,hid).clone().to(inner.forward_dtype)
        zL=inner.L_init.unsqueeze(0).expand(B,sf,hid).clone().to(inner.forward_dtype)
        vH=torch.randn(B,D,device=dev,generator=g); vL=torch.randn(B,D,device=dev,generator=g)
        vH,vL,_=renorm(vH,vL); logsum=torch.zeros(B,device=dev); nstep=0
        for seg in tqdm(range(n_seg), desc="FTLE (segments)"):
            with torch.enable_grad():
                zH,zL=zH.detach(),zL.detach()
                for _h in range(inner.config.H_cycles):
                    for _l in range(inner.config.L_cycles):
                        vc=(vH+vL).reshape(B,sf,hid).to(inner.forward_dtype)
                        zL,Dv=jvp(lambda z: inner.L_level(z, zH+emb, **si), zL, vc)
                        vL=Dv.reshape(B,D).float(); vH,vL,grow=renorm(vH,vL); logsum+=grow.log(); nstep+=1
                    vc=(vH+vL).reshape(B,sf,hid).to(inner.forward_dtype)
                    zH,Dv=jvp(lambda z: Hmod(z, zL, **si), zH, vc)
                    vH=Dv.reshape(B,D).float(); vH,vL,grow=renorm(vH,vL); logsum+=grow.log(); nstep+=1
        ftle=(logsum/nstep).cpu().numpy()
        ok=(((inner.lm_head(zH)[:,pe:].float().argmax(-1)==Y)|~m).all(-1)).cpu().numpy()
    return ftle, ok

ftle, succ = leading_ftle(inp, lab, pid, n=64)  # n=64 already separates cleanly; raise for tighter histograms
print(f"success rate {succ.mean():.2f} | median λ1 success {np.median(ftle[succ]):+.4f} vs failure {np.median(ftle[~succ]):+.4f}")
print(f"AUC(-λ1 -> success) = {auc(-ftle, succ.astype(int)):.3f} (>0.5 means failures are more chaotic)")
plt.figure(figsize=(6,4))
b=np.linspace(ftle.min(), ftle.max(), 40)
plt.hist(ftle[succ], b, alpha=.6, color='g', density=True, label=f'success (n={succ.sum()})')
plt.hist(ftle[~succ], b, alpha=.6, color='r', density=True, label=f'failure (n={(~succ).sum()})')
plt.axvline(0, ls=':', c='k', lw=1)
plt.xlabel('leading finite-time Lyapunov exponent λ1'); plt.ylabel('density')
plt.title(f'{MODEL}: failures are more chaotic'); plt.legend(); plt.tight_layout(); plt.show()""")

# ---------------------------------------------------------------- 3. extended rollout
md("""## 3. Why — transient chaos: failures *escape* with more compute

Run the recurrence `N_SEG` segments (far past the 16-segment budget) and watch the fate of trajectories that fail at segment 16. **TRM** failures escape the chaotic transient and resolve to the correct answer; **HRM** failures are far more strongly trapped. Re-run cell 1 with `MODEL="hrm_sudoku"` (restart kernel) to compare.""")

code("""def extended_rollout(inp, lab, pid, n=256, n_seg=128, seed=0):
    rng=np.random.default_rng(seed); idx=rng.choice(len(inp), n, replace=False)
    pe=inner.puzzle_emb_len; sf=inner.config.seq_len+pe; hid=inner.config.hidden_size
    is_hrm=hasattr(inner, "H_level") and getattr(inner,"H_level",None) is not None
    Hmod=inner.H_level if is_hrm else inner.L_level
    X=torch.tensor(inp[idx].astype(np.int32),device=dev); Y=torch.tensor(lab[idx].astype(np.int32),device=dev)
    P=torch.tensor(pid[idx].astype(np.int32),device=dev)
    EX=[]; DR=[]
    with torch.no_grad():
        zH=inner.H_init.unsqueeze(0).expand(n,sf,hid).clone().to(inner.forward_dtype)
        zL=inner.L_init.unsqueeze(0).expand(n,sf,hid).clone().to(inner.forward_dtype)
        si=dict(cos_sin=inner.rotary_emb() if hasattr(inner,"rotary_emb") else None)
        emb=inner._input_embeddings(X,P); m=Y>0; prev=None
        for _ in tqdm(range(n_seg), desc="rollout (segments)"):
            for _h in range(inner.config.H_cycles):
                for _l in range(inner.config.L_cycles):
                    zL=inner.L_level(zL, zH+emb, **si)
                zH=Hmod(zH, zL, **si)
            p=inner.lm_head(zH)[:,pe:].float().argmax(-1)
            EX.append(((p==Y)|~m).all(-1).float().cpu().numpy())
            DR.append((torch.zeros(n) if prev is None else (zH-prev).float().flatten(1).norm(dim=1).cpu()).numpy())
            prev=zH.detach()
    return np.stack(EX,1), np.stack(DR,1)

ex, dr = extended_rollout(inp, lab, pid, n=256, n_seg=128)
T=ex.shape[1]; fail=ex[:,15]==0; nf=fail.sum()
print(f"accuracy @16={ex[:,15].mean():.3f} @{T}={ex[:,-1].mean():.3f}")
print(f"of {nf} step-16 failures: self-resolve to CORRECT by seg{T}: {(fail&(ex[:,-1]==1)).sum()/max(nf,1)*100:.0f}%")
fig,ax=plt.subplots(1,2,figsize=(11,4))
ax[0].plot(range(1,T+1), ex.mean(0)); ax[0].axvline(16,ls='--',c='gray'); ax[0].set_xscale('log')
ax[0].set_xlabel('inference segments'); ax[0].set_ylabel('accuracy'); ax[0].set_title('accuracy vs compute')
S=[(fail&(ex[:,:s].max(1)==0)).sum()/max(nf,1) for s in range(16,T+1)]
ax[1].plot(range(16,T+1),S); ax[1].set_yscale('log'); ax[1].set_xlabel('segments'); ax[1].set_ylabel('frac failures still unsolved')
ax[1].set_title('escape from chaotic set (straight line on log-y = exponential escape)'); plt.tight_layout(); plt.show()""")

# ---------------------------------------------------------------- 4. basin accessibility
md("""## 4. Basin accessibility — input-determined or initial-condition-determined?

The puzzle is re-injected every segment (`z_H + input_embeddings`), so perturbing only the **initial** latent `z0` is a clean initial-condition change that leaves the input intact. Restart each step-16 failure `K` times from `z0 + sigma*noise`: if a small kick frees it, the solution basin is large and accessible (TRM); if no nearby IC escapes, the trapping is input-determined (HRM has a hard core).

(GPU: seconds. Laptop/CPU with TRM: a couple of minutes — lower `n`/`K`.)""")

code("""def perturb_z0(inp, lab, pid, n=96, K=8, sigmas=(0.0, 0.1, 0.3, 1.0), n_seg=48, readout=16, seed=0):
    rng=np.random.default_rng(seed); idx=rng.choice(len(inp), n, replace=False)
    pe=inner.puzzle_emb_len; sf=inner.config.seq_len+pe; hid=inner.config.hidden_size
    is_hrm=hasattr(inner,"H_level") and getattr(inner,"H_level",None) is not None
    Hmod=inner.H_level if is_hrm else inner.L_level
    sc=float(inner.H_init.float().std()); g=torch.Generator(device=dev).manual_seed(seed)
    X=torch.tensor(inp[idx].astype(np.int32),device=dev); Y=torch.tensor(lab[idx].astype(np.int32),device=dev)
    P=torch.tensor(pid[idx].astype(np.int32),device=dev)
    si=dict(cos_sin=inner.rotary_emb() if hasattr(inner,"rotary_emb") else None)
    solve=np.zeros((n,len(sigmas),K),bool); base=None
    with torch.no_grad():
        emb0=inner._input_embeddings(X,P); m0=Y>0
        for sj,sg in enumerate(tqdm(sigmas, desc="basin (sigma levels)")):
            emb=emb0.repeat_interleave(K,0); Yr=Y.repeat_interleave(K,0); mr=m0.repeat_interleave(K,0); B=n*K
            zH=inner.H_init.unsqueeze(0).expand(B,sf,hid).clone().to(inner.forward_dtype)
            zL=inner.L_init.unsqueeze(0).expand(B,sf,hid).clone().to(inner.forward_dtype)
            if sg>0:
                zH=(zH.float()+sg*sc*torch.randn(zH.shape,generator=g,device=dev)).to(inner.forward_dtype)
                zL=(zL.float()+sg*sc*torch.randn(zL.shape,generator=g,device=dev)).to(inner.forward_dtype)
            solved=torch.zeros(B,dtype=torch.bool,device=dev)
            for s in range(n_seg):
                for _h in range(inner.config.H_cycles):
                    for _l in range(inner.config.L_cycles):
                        zL=inner.L_level(zL,zH+emb,**si)
                    zH=Hmod(zH,zL,**si)
                ok=((inner.lm_head(zH)[:,pe:].float().argmax(-1)==Yr)|~mr).all(-1); solved|=ok
                if sj==0 and s==readout-1: base=ok.view(n,K)[:,0].cpu().numpy()
            solve[:,sj]=solved.view(n,K).cpu().numpy()
    return solve, base, np.array(sigmas)

solve, base, sg = perturb_z0(inp, lab, pid)
fail=~base; nf=int(fail.sum())
print(f"{nf} of {len(base)} puzzles fail@16; freeing them by restarting from a perturbed IC:")
for j,s in enumerate(sg):
    sub=solve[fail,j]; print(f" sigma={s:.1f}: single-restart={sub.mean():.2f} best-of-K={sub.any(1).mean():.2f}")
plt.figure(figsize=(6,4))
plt.plot(sg,[solve[fail,j].mean() for j in range(len(sg))],'o--',label='single restart')
plt.plot(sg,[solve[fail,j].any(1).mean() for j in range(len(sg))],'s-',label='best-of-K')
plt.xlabel('relative IC noise sigma'); plt.ylabel('solve rate (failing puzzles)')
plt.title('basin accessibility: does a restart free a trapped puzzle?'); plt.legend(); plt.grid(alpha=.3); plt.show()""")

# ---------------------------------------------------------------- summary
md("""## What this shows

- **The result (cell 2):** in the same trained network, failed trajectories have a higher leading finite-time Lyapunov exponent than successful ones — failure is locally more chaotic.
- **Why (cell 3):** that chaos is a *transient*. Failures sit on a chaotic **saddle**, not a wrong fixed point — TRM's escape and self-correct with more compute; HRM's are much more strongly trapped (still a saddle, just a far smaller escape rate). The per-segment escape gap is mostly compute-per-segment (TRM evaluates its module 21×/segment vs HRM 6×; per module-eval the gap is only ~1.6×). The "spurious fixed point" reading from 2D PCA is an artifact of projecting high-dimensional chaotic wandering.
- **Basin (cell 4):** a small initial-condition kick frees most of TRM's trapped puzzles (IC-determined, large basin); a hard core of HRM's escapes no nearby IC (input-determined).

Try: change `MODEL` (restart kernel), `n`/`n_seg`, and compare TRM vs HRM at every step.""")

# ---------------------------------------------------------------- write notebook
nb["cells"] = C
# The notebook is written into the directory that contains this script.
out = Path(__file__).resolve().parent / "recursive_reasoning_chaos.ipynb"
nbf.write(nb, str(out))
print("wrote", out, f"({len(C)} cells)")