summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--notebooks/build_notebook.py269
-rw-r--r--notebooks/recursive_reasoning_chaos.ipynb311
2 files changed, 321 insertions, 259 deletions
diff --git a/notebooks/build_notebook.py b/notebooks/build_notebook.py
index 0eec83d..8a525b2 100644
--- a/notebooks/build_notebook.py
+++ b/notebooks/build_notebook.py
@@ -1,7 +1,7 @@
"""Build notebooks/recursive_reasoning_chaos.ipynb via nbformat.
-Self-contained playground: (1) analytically-tractable transient-chaos toy (no GPU),
-(2) load a trained TRM/HRM from HuggingFace, (3) extended rollout showing TRM=escapable transient
-chaos vs HRM=trapped chaotic attractor. HF_REPO is filled in after the upload step.
+Order: load a trained TRM/HRM from HuggingFace -> (1) THE core result: leading finite-time Lyapunov
+exponent (lambda_1) along the inference trajectory separates success from failure (failures more
+chaotic) -> (2) why: transient chaos, failures escape with more compute -> (3) basin accessibility.
"""
import nbformat as nbf
from pathlib import Path
@@ -12,72 +12,36 @@ C = []
def md(t): C.append(nbf.v4.new_markdown_cell(t))
def code(t): C.append(nbf.v4.new_code_cell(t))
-md(f"""# Recursive Reasoning Failures are Chaotic — and it's *transient chaos*
+md(f"""# Recursive Reasoning Failures are Chaotic
-Small recursive reasoners (HRM, TRM) iterate a latent state to solve puzzles (Sudoku, Maze).
-Measured along the inference trajectory, **failed examples are more chaotic** (higher finite-time
-Lyapunov exponent / latent drift) than successful ones, in the *same* trained network.
+Small recursive reasoners (HRM, TRM) iterate a latent state to solve a puzzle (Sudoku) before
+emitting an answer. **The core finding:** measured along the recurrent inference trajectory, the
+**leading finite-time Lyapunov exponent (λ₁) is higher on failed examples than on successful ones**
+— in the *same* trained network. Failure is locally more chaotic.
-This notebook lets you reproduce and play with the mechanism:
-1. **Toy model** (pure numpy, no GPU) — *transient chaos*: chaotic search of latent space until the
- trajectory escapes into the solution basin. Failures = not-yet-escaped trajectories.
-2. **Real trained model** loaded from HuggingFace (`{HF_REPO}`).
-3. **Extended rollout** — run the recurrence far beyond its training budget. Both architectures'
- failures sit on a chaotic *saddle* (transient chaos), not a wrong fixed point — they just escape
- at very different rates: **TRM** failures mostly escape and self-correct given enough compute;
- **HRM** failures are far more strongly trapped (most keep churning).
-4. **Basin accessibility** — restart a trapped puzzle from perturbed initial latent states. A small
- kick frees most of TRM's (IC-determined, large basin); a hard core of HRM's never escapes any
- nearby initial condition (input-determined).
+This notebook, top to bottom:
+1. **Load** a trained model from HuggingFace (`{HF_REPO}`).
+2. **The result** — compute λ₁ per example (Benettin / JVP along the trajectory) and show the
+ success-vs-failure separation (histogram + AUC). *This is the headline.*
+3. **Why** — run the recurrence far past its training budget: failures are a *transient* — they
+ escape the chaotic set and self-correct given enough compute (TRM), or stay trapped much longer
+ (HRM). Neither settles to a wrong fixed point.
+4. **Basin accessibility** — restart from a perturbed initial state: is a failure input-determined
+ or initial-condition-determined?
Companion analysis repo: `github.com/YurenHao0426/recursive-reasoning-dynamics`.""")
md("## 0. Setup")
-code("""# minimal deps; torch+einops+pydantic are enough to load these models (TRM-Sudoku is MLP-mixer,
-# no FlashAttention needed -> runs on any GPU, even CPU).
-%pip install -q torch einops pydantic huggingface_hub numpy matplotlib
+code("""%pip install -q torch einops pydantic huggingface_hub numpy matplotlib tqdm
import numpy as np, matplotlib.pyplot as plt, torch
+from tqdm.auto import tqdm
print("torch", torch.__version__, "| cuda", torch.cuda.is_available())""")
-md("""## 1. The toy model — transient chaos (no GPU, runs in seconds)
-
-A trajectory chaotically *searches* `[0,1]` (logistic map, λ=ln2≈+0.69) until it lands within `eps`
-of the solution `s` (the "puzzle"), then it converges (λ=ln0.5<0). At a fixed readout time `T`:
-**captured = success** (FTLE low), **still searching = failure** (FTLE high). The escape time is
-~geometric (chaotic-saddle signature) and the FTLE separation is purely a *finite-time* effect —
-it vanishes as `T→∞` because everyone eventually escapes.""")
-code("""def run_toy(n=20000, T=16, eps=0.04, seed=0):
- rg = np.random.default_rng(seed)
- s = rg.uniform(0.15, 0.85, n); x = rg.uniform(0, 1, n)
- captured = np.zeros(n, bool); logd = np.zeros(n)
- for t in range(T):
- search = ~captured
- ld = np.where(search, np.log(np.abs(4*(1-2*x))+1e-12), np.log(0.5))
- xn = np.where(search, 4*x*(1-x), s + 0.5*(x-s))
- captured |= search & (np.abs(xn-s) < eps); x = xn; logd += ld
- ftle = logd / T
- success = captured & (np.abs(x-s) < 0.05)
- return ftle, success
+md(f"""## 1. Load a trained model from HuggingFace
-def auc(score, y):
- p, n = score[y==1], score[y==0]; a=np.concatenate([p,n]); o=np.argsort(a)
- r=np.empty(len(a)); r[o]=np.arange(1,len(a)+1)
- return (r[:len(p)].sum()-len(p)*(len(p)+1)/2)/(len(p)*len(n))
-
-ftle, succ = run_toy(T=16)
-print(f"success rate {succ.mean():.2f} | FTLE success {np.median(ftle[succ]):+.3f} vs failure {np.median(ftle[~succ]):+.3f}")
-print(f"AUC(-FTLE -> success) = {auc(-ftle, succ.astype(int)):.3f} (failure more chaotic)")
-fig,ax=plt.subplots(1,2,figsize=(11,4))
-b=np.linspace(-0.5,0.75,50)
-ax[0].hist(ftle[succ],b,alpha=.6,color='g',density=True,label='success'); ax[0].hist(ftle[~succ],b,alpha=.6,color='r',density=True,label='failure')
-ax[0].set_title('toy: failure more chaotic'); ax[0].set_xlabel('finite-time Lyapunov exp'); ax[0].legend()
-Ts=[4,8,16,32,64,128,256]; A=[auc(-run_toy(T=T)[0],run_toy(T=T)[1].astype(int)) for T in Ts]; R=[run_toy(T=T)[1].mean() for T in Ts]
-ax[1].plot(Ts,A,'o-',label='AUC(-FTLE->success)'); ax[1].plot(Ts,R,'s--',label='success rate'); ax[1].set_xscale('log')
-ax[1].set_xlabel('readout time T'); ax[1].set_title('finite-time: separation vanishes as T->inf'); ax[1].legend(); plt.tight_layout(); plt.show()""")
-
-md(f"""## 2. Load a trained model from HuggingFace
-
-Downloads the model code + checkpoint + config from `{HF_REPO}`. `MODEL` ∈ {{`trm_sudoku`, `hrm_sudoku`}}.""")
+Downloads model code + checkpoint + a 2000-puzzle test set from `{HF_REPO}`.
+`MODEL` ∈ {{`trm_sudoku`, `hrm_sudoku`}}. **TRM is MLP-only → runs on a laptop CPU.** To switch
+models, change `MODEL` and **restart the kernel** (the two ship same-named `models` packages).""")
code(f"""import sys, yaml, json
from pathlib import Path
from huggingface_hub import snapshot_download
@@ -85,8 +49,6 @@ from huggingface_hub import snapshot_download
HF_REPO = "{HF_REPO}"
MODEL = "trm_sudoku" # or "hrm_sudoku"
root = Path(snapshot_download(HF_REPO))
-# TRM and HRM ship separate `models/` packages -> put the right one on the path.
-# (To switch MODEL, restart the kernel: Python caches the `models` package name.)
sys.path.insert(0, str(root / ("code_trm" if MODEL.startswith("trm") else "code_hrm")))
cfg = yaml.safe_load((root / MODEL / "all_config.yaml").read_text())
@@ -106,14 +68,86 @@ inp = np.load(root/"data"/"sudoku_test_inputs.npy"); lab = np.load(root/"data"/"
pid = np.load(root/"data"/"sudoku_test_pid.npy")
print(f"loaded {{MODEL}}: hidden={{inner.config.hidden_size}}, H_cycles={{inner.config.H_cycles}}, L_cycles={{inner.config.L_cycles}}, test puzzles={{len(inp)}}")""")
-md("""## 3. Extended rollout — the mechanism
+md("""## 2. The core result — failures are more chaotic (leading FTLE / λ₁)
+
+For each puzzle we run the recurrence for the 16-segment inference budget and propagate one tangent
+vector through every module update (forward-mode JVP), renormalizing each step and accumulating the
+log-growth (Benettin's method for the largest exponent). λ₁ = mean log-growth per module-evaluation.
+Then split by outcome at segment 16. **Failures sit at higher λ₁** — they are locally expanding /
+chaotic; successes have collapsed toward the solution. `AUC(−λ₁ → success)` near 1 = clean separation.
+
+The JVP uses the fact that the update is `module(a, b) = layers(a + b)`, so a perturbation of the
+combined input can be fed through one slot. (GPU: ~1 min. Laptop/CPU with TRM: a few min — lower `n`.)""")
+code("""import torch.autograd.functional as AF
+from contextlib import nullcontext
+try: # HRM attention JVP needs the math SDP backend (no FlashAttn double-backward)
+ from torch.nn.attention import sdpa_kernel, SDPBackend
+ MATHCTX = lambda: sdpa_kernel(SDPBackend.MATH)
+except Exception:
+ MATHCTX = nullcontext
-Run the recurrence `N_SEG` segments (far past the 16-segment training budget) and watch the fate of
-trajectories that fail at segment 16. Re-run cell 2 with `MODEL="hrm_sudoku"` to see the contrast.""")
+def auc(score, y):
+ p, n = score[y==1], score[y==0]
+ if len(p)==0 or len(n)==0: return float('nan')
+ a=np.concatenate([p,n]); o=np.argsort(a); r=np.empty(len(a)); r[o]=np.arange(1,len(a)+1)
+ return (r[:len(p)].sum()-len(p)*(len(p)+1)/2)/(len(p)*len(n))
+
+def leading_ftle(inp, lab, pid, n=128, n_seg=16, seed=0):
+ rng=np.random.default_rng(seed); idx=rng.choice(len(inp), n, replace=False)
+ pe=inner.puzzle_emb_len; sf=inner.config.seq_len+pe; hid=inner.config.hidden_size; D=sf*hid; B=n
+ is_hrm=hasattr(inner,"H_level") and getattr(inner,"H_level",None) is not None
+ Hmod=inner.H_level if is_hrm else inner.L_level # weight-tied TRM reuses L_level
+ X=torch.tensor(inp[idx].astype(np.int32),device=dev); Y=torch.tensor(lab[idx].astype(np.int32),device=dev)
+ P=torch.tensor(pid[idx].astype(np.int32),device=dev)
+ si=dict(cos_sin=inner.rotary_emb() if hasattr(inner,"rotary_emb") else None)
+ g=torch.Generator(device=dev).manual_seed(seed)
+ jvp=lambda f,x,v: AF.jvp(f, x, v=v, create_graph=False, strict=False)
+ def renorm(vH,vL):
+ nrm=torch.sqrt(vH.pow(2).sum(1,keepdim=True)+vL.pow(2).sum(1,keepdim=True)).clamp_min(1e-30)
+ return vH/nrm, vL/nrm, nrm.squeeze(1)
+ with MATHCTX():
+ emb=inner._input_embeddings(X,P); m=Y>0
+ zH=inner.H_init.unsqueeze(0).expand(B,sf,hid).clone().to(inner.forward_dtype)
+ zL=inner.L_init.unsqueeze(0).expand(B,sf,hid).clone().to(inner.forward_dtype)
+ vH=torch.randn(B,D,device=dev,generator=g); vL=torch.randn(B,D,device=dev,generator=g)
+ vH,vL,_=renorm(vH,vL); logsum=torch.zeros(B,device=dev); nstep=0
+ for seg in tqdm(range(n_seg), desc="FTLE (segments)"):
+ with torch.enable_grad():
+ zH,zL=zH.detach(),zL.detach()
+ for _h in range(inner.config.H_cycles):
+ for _l in range(inner.config.L_cycles):
+ vc=(vH+vL).reshape(B,sf,hid).to(inner.forward_dtype)
+ zL,Dv=jvp(lambda z: inner.L_level(z, zH+emb, **si), zL, vc)
+ vL=Dv.reshape(B,D).float(); vH,vL,grow=renorm(vH,vL); logsum+=grow.log(); nstep+=1
+ vc=(vH+vL).reshape(B,sf,hid).to(inner.forward_dtype)
+ zH,Dv=jvp(lambda z: Hmod(z, zL, **si), zH, vc)
+ vH=Dv.reshape(B,D).float(); vH,vL,grow=renorm(vH,vL); logsum+=grow.log(); nstep+=1
+ ftle=(logsum/nstep).cpu().numpy()
+ ok=(((inner.lm_head(zH)[:,pe:].float().argmax(-1)==Y)|~m).all(-1)).cpu().numpy()
+ return ftle, ok
+
+ftle, succ = leading_ftle(inp, lab, pid, n=128)
+print(f"success rate {succ.mean():.2f} | median λ1 success {np.median(ftle[succ]):+.4f} vs failure {np.median(ftle[~succ]):+.4f}")
+print(f"AUC(-λ1 -> success) = {auc(-ftle, succ.astype(int)):.3f} (>0.5 means failures are more chaotic)")
+plt.figure(figsize=(6,4))
+b=np.linspace(ftle.min(), ftle.max(), 40)
+plt.hist(ftle[succ], b, alpha=.6, color='g', density=True, label=f'success (n={succ.sum()})')
+plt.hist(ftle[~succ], b, alpha=.6, color='r', density=True, label=f'failure (n={(~succ).sum()})')
+plt.axvline(0, ls=':', c='k', lw=1)
+plt.xlabel('leading finite-time Lyapunov exponent λ1'); plt.ylabel('density')
+plt.title(f'{MODEL}: failures are more chaotic'); plt.legend(); plt.tight_layout(); plt.show()""")
+
+md("""## 3. Why — transient chaos: failures *escape* with more compute
+
+Run the recurrence `N_SEG` segments (far past the 16-segment budget) and watch the fate of
+trajectories that fail at segment 16. **TRM** failures escape the chaotic transient and resolve to
+the correct answer; **HRM** failures are far more strongly trapped. Re-run cell 1 with
+`MODEL="hrm_sudoku"` (restart kernel) to compare.""")
code("""def extended_rollout(inp, lab, pid, n=256, n_seg=128, seed=0):
rng=np.random.default_rng(seed); idx=rng.choice(len(inp), n, replace=False)
pe=inner.puzzle_emb_len; sf=inner.config.seq_len+pe; hid=inner.config.hidden_size
- is_hrm = hasattr(inner, "H_level")
+ is_hrm=hasattr(inner, "H_level") and getattr(inner,"H_level",None) is not None
+ Hmod=inner.H_level if is_hrm else inner.L_level
X=torch.tensor(inp[idx].astype(np.int32),device=dev); Y=torch.tensor(lab[idx].astype(np.int32),device=dev)
P=torch.tensor(pid[idx].astype(np.int32),device=dev)
EX=[]; DR=[]
@@ -122,11 +156,11 @@ code("""def extended_rollout(inp, lab, pid, n=256, n_seg=128, seed=0):
zL=inner.L_init.unsqueeze(0).expand(n,sf,hid).clone().to(inner.forward_dtype)
si=dict(cos_sin=inner.rotary_emb() if hasattr(inner,"rotary_emb") else None)
emb=inner._input_embeddings(X,P); m=Y>0; prev=None
- for _ in range(n_seg):
+ for _ in tqdm(range(n_seg), desc="rollout (segments)"):
for _h in range(inner.config.H_cycles):
for _l in range(inner.config.L_cycles):
zL=inner.L_level(zL, zH+emb, **si)
- zH=(inner.H_level if is_hrm else inner.L_level)(zH, zL, **si)
+ zH=Hmod(zH, zL, **si)
p=inner.lm_head(zH)[:,pe:].float().argmax(-1)
EX.append(((p==Y)|~m).all(-1).float().cpu().numpy())
DR.append((torch.zeros(n) if prev is None else (zH-prev).float().flatten(1).norm(1).cpu()).numpy())
@@ -136,77 +170,74 @@ code("""def extended_rollout(inp, lab, pid, n=256, n_seg=128, seed=0):
ex, dr = extended_rollout(inp, lab, pid, n=256, n_seg=128)
T=ex.shape[1]; fail=ex[:,15]==0; nf=fail.sum()
print(f"accuracy @16={ex[:,15].mean():.3f} @{T}={ex[:,-1].mean():.3f}")
-print(f"of {nf} step-16 failures: self-resolve to CORRECT by seg{T}: {(fail&(ex[:,-1]==1)).sum()/nf*100:.0f}%")
-ld=dr[:,-4:].mean(1)
-print(f"median latent drift -- failures {np.median(ld[fail]):.1f} vs successes {np.median(ld[ex[:,15]==1]):.1f}")
+print(f"of {nf} step-16 failures: self-resolve to CORRECT by seg{T}: {(fail&(ex[:,-1]==1)).sum()/max(nf,1)*100:.0f}%")
fig,ax=plt.subplots(1,2,figsize=(11,4))
ax[0].plot(range(1,T+1), ex.mean(0)); ax[0].axvline(16,ls='--',c='gray'); ax[0].set_xscale('log')
ax[0].set_xlabel('inference segments'); ax[0].set_ylabel('accuracy'); ax[0].set_title('accuracy vs compute')
-S=[(fail&(ex[:,:s].max(1)==0)).sum()/nf for s in range(16,T+1)]
+S=[(fail&(ex[:,:s].max(1)==0)).sum()/max(nf,1) for s in range(16,T+1)]
ax[1].plot(range(16,T+1),S); ax[1].set_yscale('log'); ax[1].set_xlabel('segments'); ax[1].set_ylabel('frac failures still unsolved')
ax[1].set_title('escape from chaotic set (straight line on log-y = exponential escape)'); plt.tight_layout(); plt.show()""")
md("""## 4. Basin accessibility — input-determined or initial-condition-determined?
-The puzzle is re-injected at *every* segment (`z_H + input_embeddings`), so perturbing only the
-**initial** latent state `z0` is a clean initial-condition change that leaves the input intact.
-Restart each step-16 failure `K` times from `z0 + sigma*noise`: if a small kick frees it (some
-restart solves), the solution basin is large and accessible — *initial-condition-determined*; if no
-nearby IC escapes, the trapping is *input-determined*. TRM: a small kick frees most. HRM: a hard
-core escapes no nearby IC. (GPU: seconds. Laptop/CPU with TRM: a couple of minutes — lower `n`/`K`.)""")
+The puzzle is re-injected every segment (`z_H + input_embeddings`), so perturbing only the
+**initial** latent `z0` is a clean initial-condition change that leaves the input intact. Restart
+each step-16 failure `K` times from `z0 + sigma*noise`: if a small kick frees it, the solution basin
+is large and accessible (TRM); if no nearby IC escapes, the trapping is input-determined (HRM has a
+hard core). (GPU: seconds. Laptop/CPU with TRM: a couple of minutes — lower `n`/`K`.)""")
code("""def perturb_z0(inp, lab, pid, n=96, K=8, sigmas=(0.0, 0.1, 0.3, 1.0), n_seg=48, readout=16, seed=0):
- rng = np.random.default_rng(seed); idx = rng.choice(len(inp), n, replace=False)
- pe = inner.puzzle_emb_len; sf = inner.config.seq_len + pe; hid = inner.config.hidden_size
- is_hrm = hasattr(inner, "H_level") and getattr(inner, "H_level", None) is not None
- Hup = inner.H_level if is_hrm else inner.L_level # weight-tied TRM reuses L_level
- sc = float(inner.H_init.float().std()); g = torch.Generator(device=dev).manual_seed(seed)
- X = torch.tensor(inp[idx].astype(np.int32), device=dev); Y = torch.tensor(lab[idx].astype(np.int32), device=dev)
- P = torch.tensor(pid[idx].astype(np.int32), device=dev)
- si = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None)
- solve = np.zeros((n, len(sigmas), K), bool); base = None
+ rng=np.random.default_rng(seed); idx=rng.choice(len(inp), n, replace=False)
+ pe=inner.puzzle_emb_len; sf=inner.config.seq_len+pe; hid=inner.config.hidden_size
+ is_hrm=hasattr(inner,"H_level") and getattr(inner,"H_level",None) is not None
+ Hmod=inner.H_level if is_hrm else inner.L_level
+ sc=float(inner.H_init.float().std()); g=torch.Generator(device=dev).manual_seed(seed)
+ X=torch.tensor(inp[idx].astype(np.int32),device=dev); Y=torch.tensor(lab[idx].astype(np.int32),device=dev)
+ P=torch.tensor(pid[idx].astype(np.int32),device=dev)
+ si=dict(cos_sin=inner.rotary_emb() if hasattr(inner,"rotary_emb") else None)
+ solve=np.zeros((n,len(sigmas),K),bool); base=None
with torch.no_grad():
- emb0 = inner._input_embeddings(X, P); m0 = Y > 0
- for sj, sg in enumerate(sigmas):
- emb = emb0.repeat_interleave(K, 0); Yr = Y.repeat_interleave(K, 0); mr = m0.repeat_interleave(K, 0); B = n * K
- zH = inner.H_init.unsqueeze(0).expand(B, sf, hid).clone().to(inner.forward_dtype)
- zL = inner.L_init.unsqueeze(0).expand(B, sf, hid).clone().to(inner.forward_dtype)
- if sg > 0:
- zH = (zH.float() + sg*sc*torch.randn(zH.shape, generator=g, device=dev)).to(inner.forward_dtype)
- zL = (zL.float() + sg*sc*torch.randn(zL.shape, generator=g, device=dev)).to(inner.forward_dtype)
- solved = torch.zeros(B, dtype=torch.bool, device=dev)
+ emb0=inner._input_embeddings(X,P); m0=Y>0
+ for sj,sg in enumerate(tqdm(sigmas, desc="basin (sigma levels)")):
+ emb=emb0.repeat_interleave(K,0); Yr=Y.repeat_interleave(K,0); mr=m0.repeat_interleave(K,0); B=n*K
+ zH=inner.H_init.unsqueeze(0).expand(B,sf,hid).clone().to(inner.forward_dtype)
+ zL=inner.L_init.unsqueeze(0).expand(B,sf,hid).clone().to(inner.forward_dtype)
+ if sg>0:
+ zH=(zH.float()+sg*sc*torch.randn(zH.shape,generator=g,device=dev)).to(inner.forward_dtype)
+ zL=(zL.float()+sg*sc*torch.randn(zL.shape,generator=g,device=dev)).to(inner.forward_dtype)
+ solved=torch.zeros(B,dtype=torch.bool,device=dev)
for s in range(n_seg):
for _h in range(inner.config.H_cycles):
- for _l in range(inner.config.L_cycles): zL = inner.L_level(zL, zH + emb, **si)
- zH = Hup(zH, zL, **si)
- ok = ((inner.lm_head(zH)[:, pe:].float().argmax(-1) == Yr) | ~mr).all(-1); solved |= ok
- if sj == 0 and s == readout - 1: base = ok.view(n, K)[:, 0].cpu().numpy()
- solve[:, sj] = solved.view(n, K).cpu().numpy()
+ for _l in range(inner.config.L_cycles): zL=inner.L_level(zL,zH+emb,**si)
+ zH=Hmod(zH,zL,**si)
+ ok=((inner.lm_head(zH)[:,pe:].float().argmax(-1)==Yr)|~mr).all(-1); solved|=ok
+ if sj==0 and s==readout-1: base=ok.view(n,K)[:,0].cpu().numpy()
+ solve[:,sj]=solved.view(n,K).cpu().numpy()
return solve, base, np.array(sigmas)
solve, base, sg = perturb_z0(inp, lab, pid)
-fail = ~base; nf = int(fail.sum())
-print(f"{nf} of {len(base)} puzzles fail@{16}; freeing them by restarting from a perturbed IC:")
-for j, s in enumerate(sg):
- sub = solve[fail, j]; print(f" sigma={s:.1f}: single-restart={sub.mean():.2f} best-of-K={sub.any(1).mean():.2f}")
-plt.figure(figsize=(6, 4))
-plt.plot(sg, [solve[fail, j].mean() for j in range(len(sg))], 'o--', label='single restart')
-plt.plot(sg, [solve[fail, j].any(1).mean() for j in range(len(sg))], 's-', label='best-of-K')
+fail=~base; nf=int(fail.sum())
+print(f"{nf} of {len(base)} puzzles fail@16; freeing them by restarting from a perturbed IC:")
+for j,s in enumerate(sg):
+ sub=solve[fail,j]; print(f" sigma={s:.1f}: single-restart={sub.mean():.2f} best-of-K={sub.any(1).mean():.2f}")
+plt.figure(figsize=(6,4))
+plt.plot(sg,[solve[fail,j].mean() for j in range(len(sg))],'o--',label='single restart')
+plt.plot(sg,[solve[fail,j].any(1).mean() for j in range(len(sg))],'s-',label='best-of-K')
plt.xlabel('relative IC noise sigma'); plt.ylabel('solve rate (failing puzzles)')
plt.title('basin accessibility: does a restart free a trapped puzzle?'); plt.legend(); plt.grid(alpha=.3); plt.show()""")
md("""## What this shows
-- **TRM**: step-16 failures *escape* the chaotic transient and resolve to the correct answer
- (≈0 settle to a wrong answer) → a chaotic **saddle** + one solution fixed point. More compute
- solves more puzzles.
-- **HRM**: failures escape too, but *much* more slowly — most are still churning at this horizon.
- Out to 4000 segments the never-correct fraction keeps decaying (≈0.87→0.77), so it is a
- **strongly-trapping chaotic saddle**, NOT a strict attractor. And the per-segment escape-rate gap
- (~5×) is mostly compute-per-segment: TRM evaluates its recurrent module 21×/segment vs HRM 6×, so
- per module-evaluation the gap is only ~1.6×.
-- **Neither settles to a wrong fixed point** — the "spurious fixed point" reading from 2D PCA is an
- artifact of projecting high-dimensional chaotic wandering onto two axes.
-
-Try: change `MODEL`, `N_SEG`, `eps` (toy); compare TRM vs HRM escape curves.""")
+- **The result (cell 2):** in the same trained network, failed trajectories have a higher leading
+ finite-time Lyapunov exponent than successful ones — failure is locally more chaotic.
+- **Why (cell 3):** that chaos is a *transient*. Failures sit on a chaotic **saddle**, not a wrong
+ fixed point — TRM's escape and self-correct with more compute; HRM's are much more strongly
+ trapped (still a saddle, just a far smaller escape rate). The per-segment escape gap is mostly
+ compute-per-segment (TRM evaluates its module 21×/segment vs HRM 6×; per module-eval the gap is
+ only ~1.6×). The "spurious fixed point" reading from 2D PCA is an artifact of projecting
+ high-dimensional chaotic wandering.
+- **Basin (cell 4):** a small initial-condition kick frees most of TRM's trapped puzzles
+ (IC-determined, large basin); a hard core of HRM's escapes no nearby IC (input-determined).
+
+Try: change `MODEL` (restart kernel), `n`/`n_seg`, and compare TRM vs HRM at every step.""")
nb["cells"] = C
out = Path(__file__).resolve().parent / "recursive_reasoning_chaos.ipynb"
diff --git a/notebooks/recursive_reasoning_chaos.ipynb b/notebooks/recursive_reasoning_chaos.ipynb
index 097f9b6..1f00628 100644
--- a/notebooks/recursive_reasoning_chaos.ipynb
+++ b/notebooks/recursive_reasoning_chaos.ipynb
@@ -2,33 +2,32 @@
"cells": [
{
"cell_type": "markdown",
- "id": "6c32c5e8",
+ "id": "f991b29d",
"metadata": {},
"source": [
- "# Recursive Reasoning Failures are Chaotic — and it's *transient chaos*\n",
+ "# Recursive Reasoning Failures are Chaotic\n",
"\n",
- "Small recursive reasoners (HRM, TRM) iterate a latent state to solve puzzles (Sudoku, Maze).\n",
- "Measured along the inference trajectory, **failed examples are more chaotic** (higher finite-time\n",
- "Lyapunov exponent / latent drift) than successful ones, in the *same* trained network.\n",
+ "Small recursive reasoners (HRM, TRM) iterate a latent state to solve a puzzle (Sudoku) before\n",
+ "emitting an answer. **The core finding:** measured along the recurrent inference trajectory, the\n",
+ "**leading finite-time Lyapunov exponent (λ₁) is higher on failed examples than on successful ones**\n",
+ "— in the *same* trained network. Failure is locally more chaotic.\n",
"\n",
- "This notebook lets you reproduce and play with the mechanism:\n",
- "1. **Toy model** (pure numpy, no GPU) — *transient chaos*: chaotic search of latent space until the\n",
- " trajectory escapes into the solution basin. Failures = not-yet-escaped trajectories.\n",
- "2. **Real trained model** loaded from HuggingFace (`blackhao0426/recursive-reasoning-chaos`).\n",
- "3. **Extended rollout** — run the recurrence far beyond its training budget. Both architectures'\n",
- " failures sit on a chaotic *saddle* (transient chaos), not a wrong fixed point — they just escape\n",
- " at very different rates: **TRM** failures mostly escape and self-correct given enough compute;\n",
- " **HRM** failures are far more strongly trapped (most keep churning).\n",
- "4. **Basin accessibility** — restart a trapped puzzle from perturbed initial latent states. A small\n",
- " kick frees most of TRM's (IC-determined, large basin); a hard core of HRM's never escapes any\n",
- " nearby initial condition (input-determined).\n",
+ "This notebook, top to bottom:\n",
+ "1. **Load** a trained model from HuggingFace (`blackhao0426/recursive-reasoning-chaos`).\n",
+ "2. **The result** — compute λ₁ per example (Benettin / JVP along the trajectory) and show the\n",
+ " success-vs-failure separation (histogram + AUC). *This is the headline.*\n",
+ "3. **Why** — run the recurrence far past its training budget: failures are a *transient* — they\n",
+ " escape the chaotic set and self-correct given enough compute (TRM), or stay trapped much longer\n",
+ " (HRM). Neither settles to a wrong fixed point.\n",
+ "4. **Basin accessibility** — restart from a perturbed initial state: is a failure input-determined\n",
+ " or initial-condition-determined?\n",
"\n",
"Companion analysis repo: `github.com/YurenHao0426/recursive-reasoning-dynamics`."
]
},
{
"cell_type": "markdown",
- "id": "9161f5cf",
+ "id": "bb82e0a8",
"metadata": {},
"source": [
"## 0. Setup"
@@ -37,82 +36,32 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2034b179",
+ "id": "89a96776",
"metadata": {},
"outputs": [],
"source": [
- "# minimal deps; torch+einops+pydantic are enough to load these models (TRM-Sudoku is MLP-mixer,\n",
- "# no FlashAttention needed -> runs on any GPU, even CPU).\n",
- "%pip install -q torch einops pydantic huggingface_hub numpy matplotlib\n",
+ "%pip install -q torch einops pydantic huggingface_hub numpy matplotlib tqdm\n",
"import numpy as np, matplotlib.pyplot as plt, torch\n",
+ "from tqdm.auto import tqdm\n",
"print(\"torch\", torch.__version__, \"| cuda\", torch.cuda.is_available())"
]
},
{
"cell_type": "markdown",
- "id": "fe43cb07",
+ "id": "bbd25841",
"metadata": {},
"source": [
- "## 1. The toy model — transient chaos (no GPU, runs in seconds)\n",
+ "## 1. Load a trained model from HuggingFace\n",
"\n",
- "A trajectory chaotically *searches* `[0,1]` (logistic map, λ=ln2≈+0.69) until it lands within `eps`\n",
- "of the solution `s` (the \"puzzle\"), then it converges (λ=ln0.5<0). At a fixed readout time `T`:\n",
- "**captured = success** (FTLE low), **still searching = failure** (FTLE high). The escape time is\n",
- "~geometric (chaotic-saddle signature) and the FTLE separation is purely a *finite-time* effect —\n",
- "it vanishes as `T→∞` because everyone eventually escapes."
+ "Downloads model code + checkpoint + a 2000-puzzle test set from `blackhao0426/recursive-reasoning-chaos`.\n",
+ "`MODEL` ∈ {`trm_sudoku`, `hrm_sudoku`}. **TRM is MLP-only → runs on a laptop CPU.** To switch\n",
+ "models, change `MODEL` and **restart the kernel** (the two ship same-named `models` packages)."
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "6593c881",
- "metadata": {},
- "outputs": [],
- "source": [
- "def run_toy(n=20000, T=16, eps=0.04, seed=0):\n",
- " rg = np.random.default_rng(seed)\n",
- " s = rg.uniform(0.15, 0.85, n); x = rg.uniform(0, 1, n)\n",
- " captured = np.zeros(n, bool); logd = np.zeros(n)\n",
- " for t in range(T):\n",
- " search = ~captured\n",
- " ld = np.where(search, np.log(np.abs(4*(1-2*x))+1e-12), np.log(0.5))\n",
- " xn = np.where(search, 4*x*(1-x), s + 0.5*(x-s))\n",
- " captured |= search & (np.abs(xn-s) < eps); x = xn; logd += ld\n",
- " ftle = logd / T\n",
- " success = captured & (np.abs(x-s) < 0.05)\n",
- " return ftle, success\n",
- "\n",
- "def auc(score, y):\n",
- " p, n = score[y==1], score[y==0]; a=np.concatenate([p,n]); o=np.argsort(a)\n",
- " r=np.empty(len(a)); r[o]=np.arange(1,len(a)+1)\n",
- " return (r[:len(p)].sum()-len(p)*(len(p)+1)/2)/(len(p)*len(n))\n",
- "\n",
- "ftle, succ = run_toy(T=16)\n",
- "print(f\"success rate {succ.mean():.2f} | FTLE success {np.median(ftle[succ]):+.3f} vs failure {np.median(ftle[~succ]):+.3f}\")\n",
- "print(f\"AUC(-FTLE -> success) = {auc(-ftle, succ.astype(int)):.3f} (failure more chaotic)\")\n",
- "fig,ax=plt.subplots(1,2,figsize=(11,4))\n",
- "b=np.linspace(-0.5,0.75,50)\n",
- "ax[0].hist(ftle[succ],b,alpha=.6,color='g',density=True,label='success'); ax[0].hist(ftle[~succ],b,alpha=.6,color='r',density=True,label='failure')\n",
- "ax[0].set_title('toy: failure more chaotic'); ax[0].set_xlabel('finite-time Lyapunov exp'); ax[0].legend()\n",
- "Ts=[4,8,16,32,64,128,256]; A=[auc(-run_toy(T=T)[0],run_toy(T=T)[1].astype(int)) for T in Ts]; R=[run_toy(T=T)[1].mean() for T in Ts]\n",
- "ax[1].plot(Ts,A,'o-',label='AUC(-FTLE->success)'); ax[1].plot(Ts,R,'s--',label='success rate'); ax[1].set_xscale('log')\n",
- "ax[1].set_xlabel('readout time T'); ax[1].set_title('finite-time: separation vanishes as T->inf'); ax[1].legend(); plt.tight_layout(); plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "aee64679",
- "metadata": {},
- "source": [
- "## 2. Load a trained model from HuggingFace\n",
- "\n",
- "Downloads the model code + checkpoint + config from `blackhao0426/recursive-reasoning-chaos`. `MODEL` ∈ {`trm_sudoku`, `hrm_sudoku`}."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5f5f69ff",
+ "id": "f6a83ba0",
"metadata": {},
"outputs": [],
"source": [
@@ -123,8 +72,6 @@
"HF_REPO = \"blackhao0426/recursive-reasoning-chaos\"\n",
"MODEL = \"trm_sudoku\" # or \"hrm_sudoku\"\n",
"root = Path(snapshot_download(HF_REPO))\n",
- "# TRM and HRM ship separate `models/` packages -> put the right one on the path.\n",
- "# (To switch MODEL, restart the kernel: Python caches the `models` package name.)\n",
"sys.path.insert(0, str(root / (\"code_trm\" if MODEL.startswith(\"trm\") else \"code_hrm\")))\n",
"\n",
"cfg = yaml.safe_load((root / MODEL / \"all_config.yaml\").read_text())\n",
@@ -147,26 +94,113 @@
},
{
"cell_type": "markdown",
- "id": "cea384ce",
+ "id": "086e411f",
"metadata": {},
"source": [
- "## 3. Extended rollout — the mechanism\n",
+ "## 2. The core result — failures are more chaotic (leading FTLE / λ₁)\n",
+ "\n",
+ "For each puzzle we run the recurrence for the 16-segment inference budget and propagate one tangent\n",
+ "vector through every module update (forward-mode JVP), renormalizing each step and accumulating the\n",
+ "log-growth (Benettin's method for the largest exponent). λ₁ = mean log-growth per module-evaluation.\n",
+ "Then split by outcome at segment 16. **Failures sit at higher λ₁** — they are locally expanding /\n",
+ "chaotic; successes have collapsed toward the solution. `AUC(−λ₁ → success)` near 1 = clean separation.\n",
"\n",
- "Run the recurrence `N_SEG` segments (far past the 16-segment training budget) and watch the fate of\n",
- "trajectories that fail at segment 16. Re-run cell 2 with `MODEL=\"hrm_sudoku\"` to see the contrast."
+ "The JVP uses the fact that the update is `module(a, b) = layers(a + b)`, so a perturbation of the\n",
+ "combined input can be fed through one slot. (GPU: ~1 min. Laptop/CPU with TRM: a few min — lower `n`.)"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "5d7ec0ce",
+ "id": "791d02dc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch.autograd.functional as AF\n",
+ "from contextlib import nullcontext\n",
+ "try: # HRM attention JVP needs the math SDP backend (no FlashAttn double-backward)\n",
+ " from torch.nn.attention import sdpa_kernel, SDPBackend\n",
+ " MATHCTX = lambda: sdpa_kernel(SDPBackend.MATH)\n",
+ "except Exception:\n",
+ " MATHCTX = nullcontext\n",
+ "\n",
+ "def auc(score, y):\n",
+ " p, n = score[y==1], score[y==0]\n",
+ " if len(p)==0 or len(n)==0: return float('nan')\n",
+ " a=np.concatenate([p,n]); o=np.argsort(a); r=np.empty(len(a)); r[o]=np.arange(1,len(a)+1)\n",
+ " return (r[:len(p)].sum()-len(p)*(len(p)+1)/2)/(len(p)*len(n))\n",
+ "\n",
+ "def leading_ftle(inp, lab, pid, n=128, n_seg=16, seed=0):\n",
+ " rng=np.random.default_rng(seed); idx=rng.choice(len(inp), n, replace=False)\n",
+ " pe=inner.puzzle_emb_len; sf=inner.config.seq_len+pe; hid=inner.config.hidden_size; D=sf*hid; B=n\n",
+ " is_hrm=hasattr(inner,\"H_level\") and getattr(inner,\"H_level\",None) is not None\n",
+ " Hmod=inner.H_level if is_hrm else inner.L_level # weight-tied TRM reuses L_level\n",
+ " X=torch.tensor(inp[idx].astype(np.int32),device=dev); Y=torch.tensor(lab[idx].astype(np.int32),device=dev)\n",
+ " P=torch.tensor(pid[idx].astype(np.int32),device=dev)\n",
+ " si=dict(cos_sin=inner.rotary_emb() if hasattr(inner,\"rotary_emb\") else None)\n",
+ " g=torch.Generator(device=dev).manual_seed(seed)\n",
+ " jvp=lambda f,x,v: AF.jvp(f, x, v=v, create_graph=False, strict=False)\n",
+ " def renorm(vH,vL):\n",
+ " nrm=torch.sqrt(vH.pow(2).sum(1,keepdim=True)+vL.pow(2).sum(1,keepdim=True)).clamp_min(1e-30)\n",
+ " return vH/nrm, vL/nrm, nrm.squeeze(1)\n",
+ " with MATHCTX():\n",
+ " emb=inner._input_embeddings(X,P); m=Y>0\n",
+ " zH=inner.H_init.unsqueeze(0).expand(B,sf,hid).clone().to(inner.forward_dtype)\n",
+ " zL=inner.L_init.unsqueeze(0).expand(B,sf,hid).clone().to(inner.forward_dtype)\n",
+ " vH=torch.randn(B,D,device=dev,generator=g); vL=torch.randn(B,D,device=dev,generator=g)\n",
+ " vH,vL,_=renorm(vH,vL); logsum=torch.zeros(B,device=dev); nstep=0\n",
+ " for seg in tqdm(range(n_seg), desc=\"FTLE (segments)\"):\n",
+ " with torch.enable_grad():\n",
+ " zH,zL=zH.detach(),zL.detach()\n",
+ " for _h in range(inner.config.H_cycles):\n",
+ " for _l in range(inner.config.L_cycles):\n",
+ " vc=(vH+vL).reshape(B,sf,hid).to(inner.forward_dtype)\n",
+ " zL,Dv=jvp(lambda z: inner.L_level(z, zH+emb, **si), zL, vc)\n",
+ " vL=Dv.reshape(B,D).float(); vH,vL,grow=renorm(vH,vL); logsum+=grow.log(); nstep+=1\n",
+ " vc=(vH+vL).reshape(B,sf,hid).to(inner.forward_dtype)\n",
+ " zH,Dv=jvp(lambda z: Hmod(z, zL, **si), zH, vc)\n",
+ " vH=Dv.reshape(B,D).float(); vH,vL,grow=renorm(vH,vL); logsum+=grow.log(); nstep+=1\n",
+ " ftle=(logsum/nstep).cpu().numpy()\n",
+ " ok=(((inner.lm_head(zH)[:,pe:].float().argmax(-1)==Y)|~m).all(-1)).cpu().numpy()\n",
+ " return ftle, ok\n",
+ "\n",
+ "ftle, succ = leading_ftle(inp, lab, pid, n=128)\n",
+ "print(f\"success rate {succ.mean():.2f} | median λ1 success {np.median(ftle[succ]):+.4f} vs failure {np.median(ftle[~succ]):+.4f}\")\n",
+ "print(f\"AUC(-λ1 -> success) = {auc(-ftle, succ.astype(int)):.3f} (>0.5 means failures are more chaotic)\")\n",
+ "plt.figure(figsize=(6,4))\n",
+ "b=np.linspace(ftle.min(), ftle.max(), 40)\n",
+ "plt.hist(ftle[succ], b, alpha=.6, color='g', density=True, label=f'success (n={succ.sum()})')\n",
+ "plt.hist(ftle[~succ], b, alpha=.6, color='r', density=True, label=f'failure (n={(~succ).sum()})')\n",
+ "plt.axvline(0, ls=':', c='k', lw=1)\n",
+ "plt.xlabel('leading finite-time Lyapunov exponent λ1'); plt.ylabel('density')\n",
+ "plt.title(f'{MODEL}: failures are more chaotic'); plt.legend(); plt.tight_layout(); plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "25aa2620",
+ "metadata": {},
+ "source": [
+ "## 3. Why — transient chaos: failures *escape* with more compute\n",
+ "\n",
+ "Run the recurrence `N_SEG` segments (far past the 16-segment budget) and watch the fate of\n",
+ "trajectories that fail at segment 16. **TRM** failures escape the chaotic transient and resolve to\n",
+ "the correct answer; **HRM** failures are far more strongly trapped. Re-run cell 1 with\n",
+ "`MODEL=\"hrm_sudoku\"` (restart kernel) to compare."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "106f13d3",
"metadata": {},
"outputs": [],
"source": [
"def extended_rollout(inp, lab, pid, n=256, n_seg=128, seed=0):\n",
" rng=np.random.default_rng(seed); idx=rng.choice(len(inp), n, replace=False)\n",
" pe=inner.puzzle_emb_len; sf=inner.config.seq_len+pe; hid=inner.config.hidden_size\n",
- " is_hrm = hasattr(inner, \"H_level\")\n",
+ " is_hrm=hasattr(inner, \"H_level\") and getattr(inner,\"H_level\",None) is not None\n",
+ " Hmod=inner.H_level if is_hrm else inner.L_level\n",
" X=torch.tensor(inp[idx].astype(np.int32),device=dev); Y=torch.tensor(lab[idx].astype(np.int32),device=dev)\n",
" P=torch.tensor(pid[idx].astype(np.int32),device=dev)\n",
" EX=[]; DR=[]\n",
@@ -175,11 +209,11 @@
" zL=inner.L_init.unsqueeze(0).expand(n,sf,hid).clone().to(inner.forward_dtype)\n",
" si=dict(cos_sin=inner.rotary_emb() if hasattr(inner,\"rotary_emb\") else None)\n",
" emb=inner._input_embeddings(X,P); m=Y>0; prev=None\n",
- " for _ in range(n_seg):\n",
+ " for _ in tqdm(range(n_seg), desc=\"rollout (segments)\"):\n",
" for _h in range(inner.config.H_cycles):\n",
" for _l in range(inner.config.L_cycles):\n",
" zL=inner.L_level(zL, zH+emb, **si)\n",
- " zH=(inner.H_level if is_hrm else inner.L_level)(zH, zL, **si)\n",
+ " zH=Hmod(zH, zL, **si)\n",
" p=inner.lm_head(zH)[:,pe:].float().argmax(-1)\n",
" EX.append(((p==Y)|~m).all(-1).float().cpu().numpy())\n",
" DR.append((torch.zeros(n) if prev is None else (zH-prev).float().flatten(1).norm(1).cpu()).numpy())\n",
@@ -189,98 +223,95 @@
"ex, dr = extended_rollout(inp, lab, pid, n=256, n_seg=128)\n",
"T=ex.shape[1]; fail=ex[:,15]==0; nf=fail.sum()\n",
"print(f\"accuracy @16={ex[:,15].mean():.3f} @{T}={ex[:,-1].mean():.3f}\")\n",
- "print(f\"of {nf} step-16 failures: self-resolve to CORRECT by seg{T}: {(fail&(ex[:,-1]==1)).sum()/nf*100:.0f}%\")\n",
- "ld=dr[:,-4:].mean(1)\n",
- "print(f\"median latent drift -- failures {np.median(ld[fail]):.1f} vs successes {np.median(ld[ex[:,15]==1]):.1f}\")\n",
+ "print(f\"of {nf} step-16 failures: self-resolve to CORRECT by seg{T}: {(fail&(ex[:,-1]==1)).sum()/max(nf,1)*100:.0f}%\")\n",
"fig,ax=plt.subplots(1,2,figsize=(11,4))\n",
"ax[0].plot(range(1,T+1), ex.mean(0)); ax[0].axvline(16,ls='--',c='gray'); ax[0].set_xscale('log')\n",
"ax[0].set_xlabel('inference segments'); ax[0].set_ylabel('accuracy'); ax[0].set_title('accuracy vs compute')\n",
- "S=[(fail&(ex[:,:s].max(1)==0)).sum()/nf for s in range(16,T+1)]\n",
+ "S=[(fail&(ex[:,:s].max(1)==0)).sum()/max(nf,1) for s in range(16,T+1)]\n",
"ax[1].plot(range(16,T+1),S); ax[1].set_yscale('log'); ax[1].set_xlabel('segments'); ax[1].set_ylabel('frac failures still unsolved')\n",
"ax[1].set_title('escape from chaotic set (straight line on log-y = exponential escape)'); plt.tight_layout(); plt.show()"
]
},
{
"cell_type": "markdown",
- "id": "912eefb8",
+ "id": "306d11c9",
"metadata": {},
"source": [
"## 4. Basin accessibility — input-determined or initial-condition-determined?\n",
"\n",
- "The puzzle is re-injected at *every* segment (`z_H + input_embeddings`), so perturbing only the\n",
- "**initial** latent state `z0` is a clean initial-condition change that leaves the input intact.\n",
- "Restart each step-16 failure `K` times from `z0 + sigma*noise`: if a small kick frees it (some\n",
- "restart solves), the solution basin is large and accessible — *initial-condition-determined*; if no\n",
- "nearby IC escapes, the trapping is *input-determined*. TRM: a small kick frees most. HRM: a hard\n",
- "core escapes no nearby IC. (GPU: seconds. Laptop/CPU with TRM: a couple of minutes — lower `n`/`K`.)"
+ "The puzzle is re-injected every segment (`z_H + input_embeddings`), so perturbing only the\n",
+ "**initial** latent `z0` is a clean initial-condition change that leaves the input intact. Restart\n",
+ "each step-16 failure `K` times from `z0 + sigma*noise`: if a small kick frees it, the solution basin\n",
+ "is large and accessible (TRM); if no nearby IC escapes, the trapping is input-determined (HRM has a\n",
+ "hard core). (GPU: seconds. Laptop/CPU with TRM: a couple of minutes — lower `n`/`K`.)"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "b812488b",
+ "id": "682927dd",
"metadata": {},
"outputs": [],
"source": [
"def perturb_z0(inp, lab, pid, n=96, K=8, sigmas=(0.0, 0.1, 0.3, 1.0), n_seg=48, readout=16, seed=0):\n",
- " rng = np.random.default_rng(seed); idx = rng.choice(len(inp), n, replace=False)\n",
- " pe = inner.puzzle_emb_len; sf = inner.config.seq_len + pe; hid = inner.config.hidden_size\n",
- " is_hrm = hasattr(inner, \"H_level\") and getattr(inner, \"H_level\", None) is not None\n",
- " Hup = inner.H_level if is_hrm else inner.L_level # weight-tied TRM reuses L_level\n",
- " sc = float(inner.H_init.float().std()); g = torch.Generator(device=dev).manual_seed(seed)\n",
- " X = torch.tensor(inp[idx].astype(np.int32), device=dev); Y = torch.tensor(lab[idx].astype(np.int32), device=dev)\n",
- " P = torch.tensor(pid[idx].astype(np.int32), device=dev)\n",
- " si = dict(cos_sin=inner.rotary_emb() if hasattr(inner, \"rotary_emb\") else None)\n",
- " solve = np.zeros((n, len(sigmas), K), bool); base = None\n",
+ " rng=np.random.default_rng(seed); idx=rng.choice(len(inp), n, replace=False)\n",
+ " pe=inner.puzzle_emb_len; sf=inner.config.seq_len+pe; hid=inner.config.hidden_size\n",
+ " is_hrm=hasattr(inner,\"H_level\") and getattr(inner,\"H_level\",None) is not None\n",
+ " Hmod=inner.H_level if is_hrm else inner.L_level\n",
+ " sc=float(inner.H_init.float().std()); g=torch.Generator(device=dev).manual_seed(seed)\n",
+ " X=torch.tensor(inp[idx].astype(np.int32),device=dev); Y=torch.tensor(lab[idx].astype(np.int32),device=dev)\n",
+ " P=torch.tensor(pid[idx].astype(np.int32),device=dev)\n",
+ " si=dict(cos_sin=inner.rotary_emb() if hasattr(inner,\"rotary_emb\") else None)\n",
+ " solve=np.zeros((n,len(sigmas),K),bool); base=None\n",
" with torch.no_grad():\n",
- " emb0 = inner._input_embeddings(X, P); m0 = Y > 0\n",
- " for sj, sg in enumerate(sigmas):\n",
- " emb = emb0.repeat_interleave(K, 0); Yr = Y.repeat_interleave(K, 0); mr = m0.repeat_interleave(K, 0); B = n * K\n",
- " zH = inner.H_init.unsqueeze(0).expand(B, sf, hid).clone().to(inner.forward_dtype)\n",
- " zL = inner.L_init.unsqueeze(0).expand(B, sf, hid).clone().to(inner.forward_dtype)\n",
- " if sg > 0:\n",
- " zH = (zH.float() + sg*sc*torch.randn(zH.shape, generator=g, device=dev)).to(inner.forward_dtype)\n",
- " zL = (zL.float() + sg*sc*torch.randn(zL.shape, generator=g, device=dev)).to(inner.forward_dtype)\n",
- " solved = torch.zeros(B, dtype=torch.bool, device=dev)\n",
+ " emb0=inner._input_embeddings(X,P); m0=Y>0\n",
+ " for sj,sg in enumerate(tqdm(sigmas, desc=\"basin (sigma levels)\")):\n",
+ " emb=emb0.repeat_interleave(K,0); Yr=Y.repeat_interleave(K,0); mr=m0.repeat_interleave(K,0); B=n*K\n",
+ " zH=inner.H_init.unsqueeze(0).expand(B,sf,hid).clone().to(inner.forward_dtype)\n",
+ " zL=inner.L_init.unsqueeze(0).expand(B,sf,hid).clone().to(inner.forward_dtype)\n",
+ " if sg>0:\n",
+ " zH=(zH.float()+sg*sc*torch.randn(zH.shape,generator=g,device=dev)).to(inner.forward_dtype)\n",
+ " zL=(zL.float()+sg*sc*torch.randn(zL.shape,generator=g,device=dev)).to(inner.forward_dtype)\n",
+ " solved=torch.zeros(B,dtype=torch.bool,device=dev)\n",
" for s in range(n_seg):\n",
" for _h in range(inner.config.H_cycles):\n",
- " for _l in range(inner.config.L_cycles): zL = inner.L_level(zL, zH + emb, **si)\n",
- " zH = Hup(zH, zL, **si)\n",
- " ok = ((inner.lm_head(zH)[:, pe:].float().argmax(-1) == Yr) | ~mr).all(-1); solved |= ok\n",
- " if sj == 0 and s == readout - 1: base = ok.view(n, K)[:, 0].cpu().numpy()\n",
- " solve[:, sj] = solved.view(n, K).cpu().numpy()\n",
+ " for _l in range(inner.config.L_cycles): zL=inner.L_level(zL,zH+emb,**si)\n",
+ " zH=Hmod(zH,zL,**si)\n",
+ " ok=((inner.lm_head(zH)[:,pe:].float().argmax(-1)==Yr)|~mr).all(-1); solved|=ok\n",
+ " if sj==0 and s==readout-1: base=ok.view(n,K)[:,0].cpu().numpy()\n",
+ " solve[:,sj]=solved.view(n,K).cpu().numpy()\n",
" return solve, base, np.array(sigmas)\n",
"\n",
"solve, base, sg = perturb_z0(inp, lab, pid)\n",
- "fail = ~base; nf = int(fail.sum())\n",
- "print(f\"{nf} of {len(base)} puzzles fail@{16}; freeing them by restarting from a perturbed IC:\")\n",
- "for j, s in enumerate(sg):\n",
- " sub = solve[fail, j]; print(f\" sigma={s:.1f}: single-restart={sub.mean():.2f} best-of-K={sub.any(1).mean():.2f}\")\n",
- "plt.figure(figsize=(6, 4))\n",
- "plt.plot(sg, [solve[fail, j].mean() for j in range(len(sg))], 'o--', label='single restart')\n",
- "plt.plot(sg, [solve[fail, j].any(1).mean() for j in range(len(sg))], 's-', label='best-of-K')\n",
+ "fail=~base; nf=int(fail.sum())\n",
+ "print(f\"{nf} of {len(base)} puzzles fail@16; freeing them by restarting from a perturbed IC:\")\n",
+ "for j,s in enumerate(sg):\n",
+ " sub=solve[fail,j]; print(f\" sigma={s:.1f}: single-restart={sub.mean():.2f} best-of-K={sub.any(1).mean():.2f}\")\n",
+ "plt.figure(figsize=(6,4))\n",
+ "plt.plot(sg,[solve[fail,j].mean() for j in range(len(sg))],'o--',label='single restart')\n",
+ "plt.plot(sg,[solve[fail,j].any(1).mean() for j in range(len(sg))],'s-',label='best-of-K')\n",
"plt.xlabel('relative IC noise sigma'); plt.ylabel('solve rate (failing puzzles)')\n",
"plt.title('basin accessibility: does a restart free a trapped puzzle?'); plt.legend(); plt.grid(alpha=.3); plt.show()"
]
},
{
"cell_type": "markdown",
- "id": "4e2c8f69",
+ "id": "22ef7f6f",
"metadata": {},
"source": [
"## What this shows\n",
- "- **TRM**: step-16 failures *escape* the chaotic transient and resolve to the correct answer\n",
- " (≈0 settle to a wrong answer) → a chaotic **saddle** + one solution fixed point. More compute\n",
- " solves more puzzles.\n",
- "- **HRM**: failures escape too, but *much* more slowly — most are still churning at this horizon.\n",
- " Out to 4000 segments the never-correct fraction keeps decaying (≈0.87→0.77), so it is a\n",
- " **strongly-trapping chaotic saddle**, NOT a strict attractor. And the per-segment escape-rate gap\n",
- " (~5×) is mostly compute-per-segment: TRM evaluates its recurrent module 21×/segment vs HRM 6×, so\n",
- " per module-evaluation the gap is only ~1.6×.\n",
- "- **Neither settles to a wrong fixed point** — the \"spurious fixed point\" reading from 2D PCA is an\n",
- " artifact of projecting high-dimensional chaotic wandering onto two axes.\n",
+ "- **The result (cell 2):** in the same trained network, failed trajectories have a higher leading\n",
+ " finite-time Lyapunov exponent than successful ones — failure is locally more chaotic.\n",
+ "- **Why (cell 3):** that chaos is a *transient*. Failures sit on a chaotic **saddle**, not a wrong\n",
+ " fixed point — TRM's escape and self-correct with more compute; HRM's are much more strongly\n",
+ " trapped (still a saddle, just a far smaller escape rate). The per-segment escape gap is mostly\n",
+ " compute-per-segment (TRM evaluates its module 21×/segment vs HRM 6×; per module-eval the gap is\n",
+ " only ~1.6×). The \"spurious fixed point\" reading from 2D PCA is an artifact of projecting\n",
+ " high-dimensional chaotic wandering.\n",
+ "- **Basin (cell 4):** a small initial-condition kick frees most of TRM's trapped puzzles\n",
+ " (IC-determined, large basin); a hard core of HRM's escapes no nearby IC (input-determined).\n",
"\n",
- "Try: change `MODEL`, `N_SEG`, `eps` (toy); compare TRM vs HRM escape curves."
+ "Try: change `MODEL` (restart kernel), `n`/`n_seg`, and compare TRM vs HRM at every step."
]
}
],