Diffstat (limited to 'notebooks/recursive_reasoning_chaos.ipynb')
| -rw-r--r-- | notebooks/recursive_reasoning_chaos.ipynb | 220 |
1 files changed, 220 insertions, 0 deletions
diff --git a/notebooks/recursive_reasoning_chaos.ipynb b/notebooks/recursive_reasoning_chaos.ipynb new file mode 100644 index 0000000..0751ae3 --- /dev/null +++ b/notebooks/recursive_reasoning_chaos.ipynb @@ -0,0 +1,220 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0347ea73", + "metadata": {}, + "source": [ + "# Recursive Reasoning Failures are Chaotic — and it's *transient chaos*\n", + "\n", + "Small recursive reasoners (HRM, TRM) iterate a latent state to solve puzzles (Sudoku, Maze).\n", + "Measured along the inference trajectory, **failed examples are more chaotic** (higher finite-time\n", + "Lyapunov exponent / latent drift) than successful ones, in the *same* trained network.\n", + "\n", + "This notebook lets you reproduce and play with the mechanism:\n", + "1. **Toy model** (pure numpy, no GPU) — *transient chaos*: chaotic search of latent space until the\n", + " trajectory escapes into the solution basin. Failures = not-yet-escaped trajectories.\n", + "2. **Real trained model** loaded from HuggingFace (`YurenHao0426/recursive-reasoning-chaos`).\n", + "3. **Extended rollout** — run the recurrence far beyond its training budget. **TRM** failures escape\n", + " (transient chaotic saddle → the model solves 96%+ given enough compute); **HRM** failures stay\n", + " trapped (a chaotic *attractor*). Neither settles to a wrong *fixed point*.\n", + "\n", + "Companion analysis repo: `github.com/YurenHao0426/recursive-reasoning-dynamics`." + ] + }, + { + "cell_type": "markdown", + "id": "23f53bbd", + "metadata": {}, + "source": [ + "## 0. 
Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40edaba4", + "metadata": {}, + "outputs": [], + "source": [ + "# minimal deps; torch+einops+pydantic are enough to load these models (TRM-Sudoku is MLP-mixer,\n", + "# no FlashAttention needed -> runs on any GPU, even CPU).\n", + "%pip install -q torch einops pydantic huggingface_hub numpy matplotlib\n", + "import numpy as np, matplotlib.pyplot as plt, torch\n", + "print(\"torch\", torch.__version__, \"| cuda\", torch.cuda.is_available())" + ] + }, + { + "cell_type": "markdown", + "id": "7c7960ce", + "metadata": {}, + "source": [ + "## 1. The toy model — transient chaos (no GPU, runs in seconds)\n", + "\n", + "A trajectory chaotically *searches* `[0,1]` (logistic map, λ=ln2≈+0.69) until it lands within `eps`\n", + "of the solution `s` (the \"puzzle\"), then it converges (λ=ln0.5<0). At a fixed readout time `T`:\n", + "**captured = success** (FTLE low), **still searching = failure** (FTLE high). The escape time is\n", + "~geometric (chaotic-saddle signature) and the FTLE separation is purely a *finite-time* effect —\n", + "it vanishes as `T→∞` because everyone eventually escapes." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e608a245", + "metadata": {}, + "outputs": [], + "source": [ + "def run_toy(n=20000, T=16, eps=0.04, seed=0):\n", + " rg = np.random.default_rng(seed)\n", + " s = rg.uniform(0.15, 0.85, n); x = rg.uniform(0, 1, n)\n", + " captured = np.zeros(n, bool); logd = np.zeros(n)\n", + " for t in range(T):\n", + " search = ~captured\n", + " ld = np.where(search, np.log(np.abs(4*(1-2*x))+1e-12), np.log(0.5))\n", + " xn = np.where(search, 4*x*(1-x), s + 0.5*(x-s))\n", + " captured |= search & (np.abs(xn-s) < eps); x = xn; logd += ld\n", + " ftle = logd / T\n", + " success = captured & (np.abs(x-s) < 0.05)\n", + " return ftle, success\n", + "\n", + "def auc(score, y):\n", + " p, n = score[y==1], score[y==0]; a=np.concatenate([p,n]); o=np.argsort(a)\n", + " r=np.empty(len(a)); r[o]=np.arange(1,len(a)+1)\n", + " return (r[:len(p)].sum()-len(p)*(len(p)+1)/2)/(len(p)*len(n))\n", + "\n", + "ftle, succ = run_toy(T=16)\n", + "print(f\"success rate {succ.mean():.2f} | FTLE success {np.median(ftle[succ]):+.3f} vs failure {np.median(ftle[~succ]):+.3f}\")\n", + "print(f\"AUC(-FTLE -> success) = {auc(-ftle, succ.astype(int)):.3f} (failure more chaotic)\")\n", + "fig,ax=plt.subplots(1,2,figsize=(11,4))\n", + "b=np.linspace(-0.5,0.75,50)\n", + "ax[0].hist(ftle[succ],b,alpha=.6,color='g',density=True,label='success'); ax[0].hist(ftle[~succ],b,alpha=.6,color='r',density=True,label='failure')\n", + "ax[0].set_title('toy: failure more chaotic'); ax[0].set_xlabel('finite-time Lyapunov exp'); ax[0].legend()\n", + "Ts=[4,8,16,32,64,128,256]; A=[auc(-run_toy(T=T)[0],run_toy(T=T)[1].astype(int)) for T in Ts]; R=[run_toy(T=T)[1].mean() for T in Ts]\n", + "ax[1].plot(Ts,A,'o-',label='AUC(-FTLE->success)'); ax[1].plot(Ts,R,'s--',label='success rate'); ax[1].set_xscale('log')\n", + "ax[1].set_xlabel('readout time T'); ax[1].set_title('finite-time: separation vanishes as T->inf'); ax[1].legend(); plt.tight_layout(); plt.show()" + ] + 
}, + { + "cell_type": "markdown", + "id": "8f3f5192", + "metadata": {}, + "source": [ + "## 2. Load a trained model from HuggingFace\n", + "\n", + "Downloads the model code + checkpoint + config from `YurenHao0426/recursive-reasoning-chaos`. `MODEL` ∈ {`trm_sudoku`, `hrm_sudoku`}." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f8972bdd", + "metadata": {}, + "outputs": [], + "source": [ + "import sys, yaml, json\n", + "from pathlib import Path\n", + "from huggingface_hub import snapshot_download\n", + "\n", + "HF_REPO = \"YurenHao0426/recursive-reasoning-chaos\"\n", + "MODEL = \"trm_sudoku\" # or \"hrm_sudoku\"\n", + "root = Path(snapshot_download(HF_REPO))\n", + "# TRM and HRM ship separate `models/` packages -> put the right one on the path.\n", + "# (To switch MODEL, restart the kernel: Python caches the `models` package name.)\n", + "sys.path.insert(0, str(root / (\"code_trm\" if MODEL.startswith(\"trm\") else \"code_hrm\")))\n", + "\n", + "cfg = yaml.safe_load((root / MODEL / \"all_config.yaml\").read_text())\n", + "meta = json.loads((root / \"data\" / \"sudoku_meta.json\").read_text())\n", + "arch = dict(cfg[\"arch\"]); arch.update(batch_size=64, seq_len=meta[\"seq_len\"], vocab_size=meta[\"vocab_size\"],\n", + " num_puzzle_identifiers=meta[\"num_puzzle_identifiers\"], causal=False)\n", + "if MODEL.startswith(\"trm\"):\n", + " from models.recursive_reasoning.trm import TinyRecursiveReasoningModel_ACTV1 as M\n", + "else:\n", + " from models.hrm.hrm_act_v1 import HierarchicalReasoningModel_ACTV1 as M\n", + "model = M(arch)\n", + "sd = torch.load(root / MODEL / \"weights.pt\", map_location=\"cpu\", weights_only=True)\n", + "model.load_state_dict({k.replace(\"_orig_mod.\",\"\").replace(\"model.\",\"\"): v for k,v in sd.items()}, strict=False)\n", + "dev = \"cuda\" if torch.cuda.is_available() else \"cpu\"; model.to(dev).eval()\n", + "inner = model.inner\n", + "inp = np.load(root/\"data\"/\"sudoku_test_inputs.npy\"); lab = 
np.load(root/\"data\"/\"sudoku_test_labels.npy\")\n", + "pid = np.load(root/\"data\"/\"sudoku_test_pid.npy\")\n", + "print(f\"loaded {MODEL}: hidden={inner.config.hidden_size}, H_cycles={inner.config.H_cycles}, L_cycles={inner.config.L_cycles}, test puzzles={len(inp)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "7d9667d2", + "metadata": {}, + "source": [ + "## 3. Extended rollout — the mechanism\n", + "\n", + "Run the recurrence for `n_seg` segments (far past the 16-segment training budget) and watch the fate of\n", + "trajectories that fail at segment 16. Restart the kernel and re-run cell 2 with `MODEL=\"hrm_sudoku\"` to see the contrast." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5143695e", + "metadata": {}, + "outputs": [], + "source": [ + "def extended_rollout(inp, lab, pid, n=256, n_seg=128, seed=0):\n", + " rng=np.random.default_rng(seed); idx=rng.choice(len(inp), n, replace=False)\n", + " pe=inner.puzzle_emb_len; sf=inner.config.seq_len+pe; hid=inner.config.hidden_size\n", + " is_hrm = hasattr(inner, \"H_level\")\n", + " X=torch.tensor(inp[idx].astype(np.int32),device=dev); Y=torch.tensor(lab[idx].astype(np.int32),device=dev)\n", + " P=torch.tensor(pid[idx].astype(np.int32),device=dev)\n", + " EX=[]; DR=[]\n", + " with torch.no_grad():\n", + " zH=inner.H_init.unsqueeze(0).expand(n,sf,hid).clone().to(inner.forward_dtype)\n", + " zL=inner.L_init.unsqueeze(0).expand(n,sf,hid).clone().to(inner.forward_dtype)\n", + " si=dict(cos_sin=inner.rotary_emb() if hasattr(inner,\"rotary_emb\") else None)\n", + " emb=inner._input_embeddings(X,P); m=Y>0; prev=None\n", + " for _ in range(n_seg):\n", + " for _h in range(inner.config.H_cycles):\n", + " for _l in range(inner.config.L_cycles):\n", + " zL=inner.L_level(zL, zH+emb, **si)\n", + " zH=(inner.H_level if is_hrm else inner.L_level)(zH, zL, **si)\n", + " p=inner.lm_head(zH)[:,pe:].float().argmax(-1)\n", + " EX.append(((p==Y)|~m).all(-1).float().cpu().numpy())\n", + " DR.append((torch.zeros(n) if prev is None 
else (zH-prev).float().flatten(1).norm(dim=1).cpu()).numpy())\n", + " prev=zH.detach()\n", + " return np.stack(EX,1), np.stack(DR,1)\n", + "\n", + "ex, dr = extended_rollout(inp, lab, pid, n=256, n_seg=128)\n", + "T=ex.shape[1]; fail=ex[:,15]==0; nf=fail.sum()\n", + "print(f\"accuracy @16={ex[:,15].mean():.3f} @{T}={ex[:,-1].mean():.3f}\")\n", + "print(f\"of {nf} step-16 failures: self-resolve to CORRECT by seg{T}: {(fail&(ex[:,-1]==1)).sum()/nf*100:.0f}%\")\n", + "ld=dr[:,-4:].mean(1)\n", + "print(f\"median latent drift -- failures {np.median(ld[fail]):.1f} vs successes {np.median(ld[ex[:,15]==1]):.1f}\")\n", + "fig,ax=plt.subplots(1,2,figsize=(11,4))\n", + "ax[0].plot(range(1,T+1), ex.mean(0)); ax[0].axvline(16,ls='--',c='gray'); ax[0].set_xscale('log')\n", + "ax[0].set_xlabel('inference segments'); ax[0].set_ylabel('accuracy'); ax[0].set_title('accuracy vs compute')\n", + "S=[(fail&(ex[:,:s].max(1)==0)).sum()/nf for s in range(16,T+1)]\n", + "ax[1].plot(range(16,T+1),S); ax[1].set_yscale('log'); ax[1].set_xlabel('segments'); ax[1].set_ylabel('frac failures still unsolved')\n", + "ax[1].set_title('escape from chaotic set (straight=transient, plateau=attractor)'); plt.tight_layout(); plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "00681eb1", + "metadata": {}, + "source": [ + "## What this shows\n", + "- **TRM**: accuracy keeps climbing with compute; step-16 failures *escape* the chaotic transient and\n", + " resolve to the correct answer (≈0 settle to a wrong answer). → a chaotic **saddle** + one solution\n", + " fixed point. *More inference compute solves more puzzles.*\n", + "- **HRM**: accuracy plateaus; failures stay **trapped** (latent keeps churning, never escapes). 
→\n", + " bistability between a stable fixed point (success) and a chaotic **attractor** (failure).\n", + "- Neither settles to a *wrong fixed point* — the \"spurious fixed point\" reading from 2D PCA is an\n", + " artifact of projecting high-dimensional chaotic wandering.\n", + "\n", + "Try: change `MODEL`, `n_seg`, `eps` (toy); compare TRM vs HRM escape curves." + ] + } + ], + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 5 +}
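
The toy-model cell in this diff states that the escape time is roughly geometric. A minimal sketch of one way to check that claim follows, assuming the same search rule as `run_toy` (logistic map, capture within `eps` of `s`). The names `escape_times` and `t_max` are hypothetical and do not appear in the notebook, and the post-capture contraction step is omitted because only the capture time is needed here.

```python
import numpy as np

def escape_times(n=20000, t_max=400, eps=0.04, seed=0):
    """First step at which each toy trajectory is captured (-1 if never).

    Hypothetical helper: mirrors the search phase of run_toy, nothing more.
    """
    rng = np.random.default_rng(seed)
    s = rng.uniform(0.15, 0.85, n)   # per-trajectory solution, as in run_toy
    x = rng.uniform(0.0, 1.0, n)     # initial latent state
    t_esc = np.full(n, -1)
    for t in range(1, t_max + 1):
        searching = t_esc < 0
        # logistic-map search step, applied only to trajectories still searching
        x = np.where(searching, 4 * x * (1 - x), x)
        hit = searching & (np.abs(x - s) < eps)
        t_esc[hit] = t
    return t_esc

t_esc = escape_times()
escaped = t_esc > 0

# Survival fraction S(t): share of trajectories still searching after t steps.
ts = np.arange(1, 65)
surv = np.array([(~escaped | (t_esc > t)).mean() for t in ts])

# A geometric escape time gives an approximately straight line in log S(t);
# the fitted slope is minus the per-step escape rate.
slope = np.polyfit(ts, np.log(surv), 1)[0]
print(f"escaped by t_max: {escaped.mean():.3f} | fitted escape rate per step: {-slope:.3f}")
```

Plotting `np.log(surv)` against `ts` gives the toy counterpart of the notebook's "frac failures still unsolved" panel: a straight line indicates a transient, a plateau indicates trapping.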
