summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--notebooks/build_notebook.py79
-rw-r--r--notebooks/recursive_reasoning_chaos.ipynb118
-rw-r--r--notebooks/upload_to_hf.py2
3 files changed, 162 insertions, 37 deletions
diff --git a/notebooks/build_notebook.py b/notebooks/build_notebook.py
index c95c9f6..0eec83d 100644
--- a/notebooks/build_notebook.py
+++ b/notebooks/build_notebook.py
@@ -6,7 +6,7 @@ chaos vs HRM=trapped chaotic attractor. HF_REPO is filled in after the upload st
import nbformat as nbf
from pathlib import Path
-HF_REPO = "YurenHao0426/recursive-reasoning-chaos" # filled after upload_to_hf.py
+HF_REPO = "blackhao0426/recursive-reasoning-chaos" # HF account (GitHub is YurenHao0426)
nb = nbf.v4.new_notebook()
C = []
def md(t): C.append(nbf.v4.new_markdown_cell(t))
@@ -22,9 +22,13 @@ This notebook lets you reproduce and play with the mechanism:
1. **Toy model** (pure numpy, no GPU) — *transient chaos*: chaotic search of latent space until the
trajectory escapes into the solution basin. Failures = not-yet-escaped trajectories.
2. **Real trained model** loaded from HuggingFace (`{HF_REPO}`).
-3. **Extended rollout** — run the recurrence far beyond its training budget. **TRM** failures escape
- (transient chaotic saddle → the model solves 96%+ given enough compute); **HRM** failures stay
- trapped (a chaotic *attractor*). Neither settles to a wrong *fixed point*.
+3. **Extended rollout** — run the recurrence far beyond its training budget. Both architectures'
+ failures sit on a chaotic *saddle* (transient chaos), not a wrong fixed point — they just escape
+ at very different rates: **TRM** failures mostly escape and self-correct given enough compute;
+ **HRM** failures are far more strongly trapped (most keep churning).
+4. **Basin accessibility** — restart a trapped puzzle from perturbed initial latent states. A small
+ kick frees most of TRM's (IC-determined, large basin); a hard core of HRM's never escapes any
+ nearby initial condition (input-determined).
Companion analysis repo: `github.com/YurenHao0426/recursive-reasoning-dynamics`.""")
@@ -140,16 +144,67 @@ ax[0].plot(range(1,T+1), ex.mean(0)); ax[0].axvline(16,ls='--',c='gray'); ax[0].
ax[0].set_xlabel('inference segments'); ax[0].set_ylabel('accuracy'); ax[0].set_title('accuracy vs compute')
S=[(fail&(ex[:,:s].max(1)==0)).sum()/nf for s in range(16,T+1)]
ax[1].plot(range(16,T+1),S); ax[1].set_yscale('log'); ax[1].set_xlabel('segments'); ax[1].set_ylabel('frac failures still unsolved')
-ax[1].set_title('escape from chaotic set (straight=transient, plateau=attractor)'); plt.tight_layout(); plt.show()""")
+ax[1].set_title('escape from chaotic set (straight line on log-y = exponential escape)'); plt.tight_layout(); plt.show()""")
+
+md("""## 4. Basin accessibility — input-determined or initial-condition-determined?
+
+The puzzle is re-injected at *every* segment (`z_H + input_embeddings`), so perturbing only the
+**initial** latent state `z0` is a clean initial-condition change that leaves the input intact.
+Restart each step-16 failure `K` times from `z0 + sigma*noise`: if a small kick frees it (some
+restart solves), the solution basin is large and accessible — *initial-condition-determined*; if no
+nearby IC escapes, the trapping is *input-determined*. TRM: a small kick frees most. HRM: a hard
+core escapes no nearby IC. (GPU: seconds. Laptop/CPU with TRM: a couple of minutes — lower `n`/`K`.)""")
+code("""def perturb_z0(inp, lab, pid, n=96, K=8, sigmas=(0.0, 0.1, 0.3, 1.0), n_seg=48, readout=16, seed=0):
+ rng = np.random.default_rng(seed); idx = rng.choice(len(inp), n, replace=False)
+ pe = inner.puzzle_emb_len; sf = inner.config.seq_len + pe; hid = inner.config.hidden_size
+ is_hrm = hasattr(inner, "H_level") and getattr(inner, "H_level", None) is not None
+ Hup = inner.H_level if is_hrm else inner.L_level # weight-tied TRM reuses L_level
+ sc = float(inner.H_init.float().std()); g = torch.Generator(device=dev).manual_seed(seed)
+ X = torch.tensor(inp[idx].astype(np.int32), device=dev); Y = torch.tensor(lab[idx].astype(np.int32), device=dev)
+ P = torch.tensor(pid[idx].astype(np.int32), device=dev)
+ si = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None)
+ solve = np.zeros((n, len(sigmas), K), bool); base = None
+ with torch.no_grad():
+ emb0 = inner._input_embeddings(X, P); m0 = Y > 0
+ for sj, sg in enumerate(sigmas):
+ emb = emb0.repeat_interleave(K, 0); Yr = Y.repeat_interleave(K, 0); mr = m0.repeat_interleave(K, 0); B = n * K
+ zH = inner.H_init.unsqueeze(0).expand(B, sf, hid).clone().to(inner.forward_dtype)
+ zL = inner.L_init.unsqueeze(0).expand(B, sf, hid).clone().to(inner.forward_dtype)
+ if sg > 0:
+ zH = (zH.float() + sg*sc*torch.randn(zH.shape, generator=g, device=dev)).to(inner.forward_dtype)
+ zL = (zL.float() + sg*sc*torch.randn(zL.shape, generator=g, device=dev)).to(inner.forward_dtype)
+ solved = torch.zeros(B, dtype=torch.bool, device=dev)
+ for s in range(n_seg):
+ for _h in range(inner.config.H_cycles):
+ for _l in range(inner.config.L_cycles): zL = inner.L_level(zL, zH + emb, **si)
+ zH = Hup(zH, zL, **si)
+ ok = ((inner.lm_head(zH)[:, pe:].float().argmax(-1) == Yr) | ~mr).all(-1); solved |= ok
+ if sj == 0 and s == readout - 1: base = ok.view(n, K)[:, 0].cpu().numpy()
+ solve[:, sj] = solved.view(n, K).cpu().numpy()
+ return solve, base, np.array(sigmas)
+
+solve, base, sg = perturb_z0(inp, lab, pid)
+fail = ~base; nf = int(fail.sum())
+print(f"{nf} of {len(base)} puzzles fail@{16}; freeing them by restarting from a perturbed IC:")
+for j, s in enumerate(sg):
+ sub = solve[fail, j]; print(f" sigma={s:.1f}: single-restart={sub.mean():.2f} best-of-K={sub.any(1).mean():.2f}")
+plt.figure(figsize=(6, 4))
+plt.plot(sg, [solve[fail, j].mean() for j in range(len(sg))], 'o--', label='single restart')
+plt.plot(sg, [solve[fail, j].any(1).mean() for j in range(len(sg))], 's-', label='best-of-K')
+plt.xlabel('relative IC noise sigma'); plt.ylabel('solve rate (failing puzzles)')
+plt.title('basin accessibility: does a restart free a trapped puzzle?'); plt.legend(); plt.grid(alpha=.3); plt.show()""")
md("""## What this shows
-- **TRM**: accuracy keeps climbing with compute; step-16 failures *escape* the chaotic transient and
- resolve to the correct answer (≈0 settle to a wrong answer). → a chaotic **saddle** + one solution
- fixed point. *More inference compute solves more puzzles.*
-- **HRM**: accuracy plateaus; failures stay **trapped** (latent keeps churning, never escapes). →
- bistability between a stable fixed point (success) and a chaotic **attractor** (failure).
-- Neither settles to a *wrong fixed point* — the "spurious fixed point" reading from 2D PCA is an
- artifact of projecting high-dimensional chaotic wandering.
+- **TRM**: step-16 failures *escape* the chaotic transient and resolve to the correct answer
+ (≈0 settle to a wrong answer) → a chaotic **saddle** + one solution fixed point. More compute
+ solves more puzzles.
+- **HRM**: failures escape too, but *much* more slowly — most are still churning at this horizon.
+ Out to 4000 segments the never-correct fraction keeps decaying (≈0.87→0.77), so it is a
+ **strongly-trapping chaotic saddle**, NOT a strict attractor. And the per-segment escape-rate gap
+ (~5×) is mostly compute-per-segment: TRM evaluates its recurrent module 21×/segment vs HRM 6×, so
+ per module-evaluation the gap is only ~1.6×.
+- **Neither settles to a wrong fixed point** — the "spurious fixed point" reading from 2D PCA is an
+ artifact of projecting high-dimensional chaotic wandering onto two axes.
Try: change `MODEL`, `N_SEG`, `eps` (toy); compare TRM vs HRM escape curves.""")
diff --git a/notebooks/recursive_reasoning_chaos.ipynb b/notebooks/recursive_reasoning_chaos.ipynb
index 0751ae3..097f9b6 100644
--- a/notebooks/recursive_reasoning_chaos.ipynb
+++ b/notebooks/recursive_reasoning_chaos.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
- "id": "0347ea73",
+ "id": "6c32c5e8",
"metadata": {},
"source": [
"# Recursive Reasoning Failures are Chaotic — and it's *transient chaos*\n",
@@ -14,17 +14,21 @@
"This notebook lets you reproduce and play with the mechanism:\n",
"1. **Toy model** (pure numpy, no GPU) — *transient chaos*: chaotic search of latent space until the\n",
" trajectory escapes into the solution basin. Failures = not-yet-escaped trajectories.\n",
- "2. **Real trained model** loaded from HuggingFace (`YurenHao0426/recursive-reasoning-chaos`).\n",
- "3. **Extended rollout** — run the recurrence far beyond its training budget. **TRM** failures escape\n",
- " (transient chaotic saddle → the model solves 96%+ given enough compute); **HRM** failures stay\n",
- " trapped (a chaotic *attractor*). Neither settles to a wrong *fixed point*.\n",
+ "2. **Real trained model** loaded from HuggingFace (`blackhao0426/recursive-reasoning-chaos`).\n",
+ "3. **Extended rollout** — run the recurrence far beyond its training budget. Both architectures'\n",
+ " failures sit on a chaotic *saddle* (transient chaos), not a wrong fixed point — they just escape\n",
+ " at very different rates: **TRM** failures mostly escape and self-correct given enough compute;\n",
+ " **HRM** failures are far more strongly trapped (most keep churning).\n",
+ "4. **Basin accessibility** — restart a trapped puzzle from perturbed initial latent states. A small\n",
+ " kick frees most of TRM's (IC-determined, large basin); a hard core of HRM's never escapes any\n",
+ " nearby initial condition (input-determined).\n",
"\n",
"Companion analysis repo: `github.com/YurenHao0426/recursive-reasoning-dynamics`."
]
},
{
"cell_type": "markdown",
- "id": "23f53bbd",
+ "id": "9161f5cf",
"metadata": {},
"source": [
"## 0. Setup"
@@ -33,7 +37,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "40edaba4",
+ "id": "2034b179",
"metadata": {},
"outputs": [],
"source": [
@@ -46,7 +50,7 @@
},
{
"cell_type": "markdown",
- "id": "7c7960ce",
+ "id": "fe43cb07",
"metadata": {},
"source": [
"## 1. The toy model — transient chaos (no GPU, runs in seconds)\n",
@@ -61,7 +65,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e608a245",
+ "id": "6593c881",
"metadata": {},
"outputs": [],
"source": [
@@ -97,18 +101,18 @@
},
{
"cell_type": "markdown",
- "id": "8f3f5192",
+ "id": "aee64679",
"metadata": {},
"source": [
"## 2. Load a trained model from HuggingFace\n",
"\n",
- "Downloads the model code + checkpoint + config from `YurenHao0426/recursive-reasoning-chaos`. `MODEL` ∈ {`trm_sudoku`, `hrm_sudoku`}."
+ "Downloads the model code + checkpoint + config from `blackhao0426/recursive-reasoning-chaos`. `MODEL` ∈ {`trm_sudoku`, `hrm_sudoku`}."
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "f8972bdd",
+ "id": "5f5f69ff",
"metadata": {},
"outputs": [],
"source": [
@@ -116,7 +120,7 @@
"from pathlib import Path\n",
"from huggingface_hub import snapshot_download\n",
"\n",
- "HF_REPO = \"YurenHao0426/recursive-reasoning-chaos\"\n",
+ "HF_REPO = \"blackhao0426/recursive-reasoning-chaos\"\n",
"MODEL = \"trm_sudoku\" # or \"hrm_sudoku\"\n",
"root = Path(snapshot_download(HF_REPO))\n",
"# TRM and HRM ship separate `models/` packages -> put the right one on the path.\n",
@@ -143,7 +147,7 @@
},
{
"cell_type": "markdown",
- "id": "7d9667d2",
+ "id": "cea384ce",
"metadata": {},
"source": [
"## 3. Extended rollout — the mechanism\n",
@@ -155,7 +159,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5143695e",
+ "id": "5d7ec0ce",
"metadata": {},
"outputs": [],
"source": [
@@ -193,22 +197,88 @@
"ax[0].set_xlabel('inference segments'); ax[0].set_ylabel('accuracy'); ax[0].set_title('accuracy vs compute')\n",
"S=[(fail&(ex[:,:s].max(1)==0)).sum()/nf for s in range(16,T+1)]\n",
"ax[1].plot(range(16,T+1),S); ax[1].set_yscale('log'); ax[1].set_xlabel('segments'); ax[1].set_ylabel('frac failures still unsolved')\n",
- "ax[1].set_title('escape from chaotic set (straight=transient, plateau=attractor)'); plt.tight_layout(); plt.show()"
+ "ax[1].set_title('escape from chaotic set (straight line on log-y = exponential escape)'); plt.tight_layout(); plt.show()"
]
},
{
"cell_type": "markdown",
- "id": "00681eb1",
+ "id": "912eefb8",
+ "metadata": {},
+ "source": [
+ "## 4. Basin accessibility — input-determined or initial-condition-determined?\n",
+ "\n",
+ "The puzzle is re-injected at *every* segment (`z_H + input_embeddings`), so perturbing only the\n",
+ "**initial** latent state `z0` is a clean initial-condition change that leaves the input intact.\n",
+ "Restart each step-16 failure `K` times from `z0 + sigma*noise`: if a small kick frees it (some\n",
+ "restart solves), the solution basin is large and accessible — *initial-condition-determined*; if no\n",
+ "nearby IC escapes, the trapping is *input-determined*. TRM: a small kick frees most. HRM: a hard\n",
+ "core escapes no nearby IC. (GPU: seconds. Laptop/CPU with TRM: a couple of minutes — lower `n`/`K`.)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b812488b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def perturb_z0(inp, lab, pid, n=96, K=8, sigmas=(0.0, 0.1, 0.3, 1.0), n_seg=48, readout=16, seed=0):\n",
+ " rng = np.random.default_rng(seed); idx = rng.choice(len(inp), n, replace=False)\n",
+ " pe = inner.puzzle_emb_len; sf = inner.config.seq_len + pe; hid = inner.config.hidden_size\n",
+ " is_hrm = hasattr(inner, \"H_level\") and getattr(inner, \"H_level\", None) is not None\n",
+ " Hup = inner.H_level if is_hrm else inner.L_level # weight-tied TRM reuses L_level\n",
+ " sc = float(inner.H_init.float().std()); g = torch.Generator(device=dev).manual_seed(seed)\n",
+ " X = torch.tensor(inp[idx].astype(np.int32), device=dev); Y = torch.tensor(lab[idx].astype(np.int32), device=dev)\n",
+ " P = torch.tensor(pid[idx].astype(np.int32), device=dev)\n",
+ " si = dict(cos_sin=inner.rotary_emb() if hasattr(inner, \"rotary_emb\") else None)\n",
+ " solve = np.zeros((n, len(sigmas), K), bool); base = None\n",
+ " with torch.no_grad():\n",
+ " emb0 = inner._input_embeddings(X, P); m0 = Y > 0\n",
+ " for sj, sg in enumerate(sigmas):\n",
+ " emb = emb0.repeat_interleave(K, 0); Yr = Y.repeat_interleave(K, 0); mr = m0.repeat_interleave(K, 0); B = n * K\n",
+ " zH = inner.H_init.unsqueeze(0).expand(B, sf, hid).clone().to(inner.forward_dtype)\n",
+ " zL = inner.L_init.unsqueeze(0).expand(B, sf, hid).clone().to(inner.forward_dtype)\n",
+ " if sg > 0:\n",
+ " zH = (zH.float() + sg*sc*torch.randn(zH.shape, generator=g, device=dev)).to(inner.forward_dtype)\n",
+ " zL = (zL.float() + sg*sc*torch.randn(zL.shape, generator=g, device=dev)).to(inner.forward_dtype)\n",
+ " solved = torch.zeros(B, dtype=torch.bool, device=dev)\n",
+ " for s in range(n_seg):\n",
+ " for _h in range(inner.config.H_cycles):\n",
+ " for _l in range(inner.config.L_cycles): zL = inner.L_level(zL, zH + emb, **si)\n",
+ " zH = Hup(zH, zL, **si)\n",
+ " ok = ((inner.lm_head(zH)[:, pe:].float().argmax(-1) == Yr) | ~mr).all(-1); solved |= ok\n",
+ " if sj == 0 and s == readout - 1: base = ok.view(n, K)[:, 0].cpu().numpy()\n",
+ " solve[:, sj] = solved.view(n, K).cpu().numpy()\n",
+ " return solve, base, np.array(sigmas)\n",
+ "\n",
+ "solve, base, sg = perturb_z0(inp, lab, pid)\n",
+ "fail = ~base; nf = int(fail.sum())\n",
+ "print(f\"{nf} of {len(base)} puzzles fail@{16}; freeing them by restarting from a perturbed IC:\")\n",
+ "for j, s in enumerate(sg):\n",
+ " sub = solve[fail, j]; print(f\" sigma={s:.1f}: single-restart={sub.mean():.2f} best-of-K={sub.any(1).mean():.2f}\")\n",
+ "plt.figure(figsize=(6, 4))\n",
+ "plt.plot(sg, [solve[fail, j].mean() for j in range(len(sg))], 'o--', label='single restart')\n",
+ "plt.plot(sg, [solve[fail, j].any(1).mean() for j in range(len(sg))], 's-', label='best-of-K')\n",
+ "plt.xlabel('relative IC noise sigma'); plt.ylabel('solve rate (failing puzzles)')\n",
+ "plt.title('basin accessibility: does a restart free a trapped puzzle?'); plt.legend(); plt.grid(alpha=.3); plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4e2c8f69",
"metadata": {},
"source": [
"## What this shows\n",
- "- **TRM**: accuracy keeps climbing with compute; step-16 failures *escape* the chaotic transient and\n",
- " resolve to the correct answer (≈0 settle to a wrong answer). → a chaotic **saddle** + one solution\n",
- " fixed point. *More inference compute solves more puzzles.*\n",
- "- **HRM**: accuracy plateaus; failures stay **trapped** (latent keeps churning, never escapes). →\n",
- " bistability between a stable fixed point (success) and a chaotic **attractor** (failure).\n",
- "- Neither settles to a *wrong fixed point* — the \"spurious fixed point\" reading from 2D PCA is an\n",
- " artifact of projecting high-dimensional chaotic wandering.\n",
+ "- **TRM**: step-16 failures *escape* the chaotic transient and resolve to the correct answer\n",
+ " (≈0 settle to a wrong answer) → a chaotic **saddle** + one solution fixed point. More compute\n",
+ " solves more puzzles.\n",
+ "- **HRM**: failures escape too, but *much* more slowly — most are still churning at this horizon.\n",
+ " Out to 4000 segments the never-correct fraction keeps decaying (≈0.87→0.77), so it is a\n",
+ " **strongly-trapping chaotic saddle**, NOT a strict attractor. And the per-segment escape-rate gap\n",
+ " (~5×) is mostly compute-per-segment: TRM evaluates its recurrent module 21×/segment vs HRM 6×, so\n",
+ " per module-evaluation the gap is only ~1.6×.\n",
+ "- **Neither settles to a wrong fixed point** — the \"spurious fixed point\" reading from 2D PCA is an\n",
+ " artifact of projecting high-dimensional chaotic wandering onto two axes.\n",
"\n",
"Try: change `MODEL`, `N_SEG`, `eps` (toy); compare TRM vs HRM escape curves."
]
diff --git a/notebooks/upload_to_hf.py b/notebooks/upload_to_hf.py
index dfdd87e..4b83747 100644
--- a/notebooks/upload_to_hf.py
+++ b/notebooks/upload_to_hf.py
@@ -7,7 +7,7 @@ from pathlib import Path
import numpy as np
from huggingface_hub import HfApi, create_repo
-HF_REPO = "YurenHao0426/recursive-reasoning-chaos"
+HF_REPO = "blackhao0426/recursive-reasoning-chaos" # HF account (GitHub is YurenHao0426)
RRM = Path("/home/yurenh2/rrm")
TRM_CK = RRM / "trm/checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_mlp_t_sudoku_official_gbs768_repro"
HRM_CK = RRM / "hrm/checkpoints/Sudoku-extreme-1k-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 righteous-python"