diff options
Diffstat (limited to 'notebooks/build_notebook.py')
| -rw-r--r-- | notebooks/build_notebook.py | 79 |
1 files changed, 67 insertions, 12 deletions
diff --git a/notebooks/build_notebook.py b/notebooks/build_notebook.py index c95c9f6..0eec83d 100644 --- a/notebooks/build_notebook.py +++ b/notebooks/build_notebook.py @@ -6,7 +6,7 @@ chaos vs HRM=trapped chaotic attractor. HF_REPO is filled in after the upload st import nbformat as nbf from pathlib import Path -HF_REPO = "YurenHao0426/recursive-reasoning-chaos" # filled after upload_to_hf.py +HF_REPO = "blackhao0426/recursive-reasoning-chaos" # HF account (GitHub is YurenHao0426) nb = nbf.v4.new_notebook() C = [] def md(t): C.append(nbf.v4.new_markdown_cell(t)) @@ -22,9 +22,13 @@ This notebook lets you reproduce and play with the mechanism: 1. **Toy model** (pure numpy, no GPU) — *transient chaos*: chaotic search of latent space until the trajectory escapes into the solution basin. Failures = not-yet-escaped trajectories. 2. **Real trained model** loaded from HuggingFace (`{HF_REPO}`). -3. **Extended rollout** — run the recurrence far beyond its training budget. **TRM** failures escape - (transient chaotic saddle → the model solves 96%+ given enough compute); **HRM** failures stay - trapped (a chaotic *attractor*). Neither settles to a wrong *fixed point*. +3. **Extended rollout** — run the recurrence far beyond its training budget. Both architectures' + failures sit on a chaotic *saddle* (transient chaos), not a wrong fixed point — they just escape + at very different rates: **TRM** failures mostly escape and self-correct given enough compute; + **HRM** failures are far more strongly trapped (most keep churning). +4. **Basin accessibility** — restart a trapped puzzle from perturbed initial latent states. A small + kick frees most of TRM's (IC-determined, large basin); a hard core of HRM's never escapes any + nearby initial condition (input-determined). Companion analysis repo: `github.com/YurenHao0426/recursive-reasoning-dynamics`.""") @@ -140,16 +144,67 @@ ax[0].plot(range(1,T+1), ex.mean(0)); ax[0].axvline(16,ls='--',c='gray'); ax[0]. ax[0].set_xlabel('inference segments'); ax[0].set_ylabel('accuracy'); ax[0].set_title('accuracy vs compute') S=[(fail&(ex[:,:s].max(1)==0)).sum()/nf for s in range(16,T+1)] ax[1].plot(range(16,T+1),S); ax[1].set_yscale('log'); ax[1].set_xlabel('segments'); ax[1].set_ylabel('frac failures still unsolved') -ax[1].set_title('escape from chaotic set (straight=transient, plateau=attractor)'); plt.tight_layout(); plt.show()""") +ax[1].set_title('escape from chaotic set (straight line on log-y = exponential escape)'); plt.tight_layout(); plt.show()""") + +md("""## 4. Basin accessibility — input-determined or initial-condition-determined? + +The puzzle is re-injected at *every* segment (`z_H + input_embeddings`), so perturbing only the +**initial** latent state `z0` is a clean initial-condition change that leaves the input intact. +Restart each step-16 failure `K` times from `z0 + sigma*noise`: if a small kick frees it (some +restart solves), the solution basin is large and accessible — *initial-condition-determined*; if no +nearby IC escapes, the trapping is *input-determined*. TRM: a small kick frees most. HRM: a hard +core escapes no nearby IC. (GPU: seconds. Laptop/CPU with TRM: a couple of minutes — lower `n`/`K`.)""") +code("""def perturb_z0(inp, lab, pid, n=96, K=8, sigmas=(0.0, 0.1, 0.3, 1.0), n_seg=48, readout=16, seed=0): + rng = np.random.default_rng(seed); idx = rng.choice(len(inp), n, replace=False) + pe = inner.puzzle_emb_len; sf = inner.config.seq_len + pe; hid = inner.config.hidden_size + is_hrm = hasattr(inner, "H_level") and getattr(inner, "H_level", None) is not None + Hup = inner.H_level if is_hrm else inner.L_level # weight-tied TRM reuses L_level + sc = float(inner.H_init.float().std()); g = torch.Generator(device=dev).manual_seed(seed) + X = torch.tensor(inp[idx].astype(np.int32), device=dev); Y = torch.tensor(lab[idx].astype(np.int32), device=dev) + P = torch.tensor(pid[idx].astype(np.int32), device=dev) + si = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None) + solve = np.zeros((n, len(sigmas), K), bool); base = None + with torch.no_grad(): + emb0 = inner._input_embeddings(X, P); m0 = Y > 0 + for sj, sg in enumerate(sigmas): + emb = emb0.repeat_interleave(K, 0); Yr = Y.repeat_interleave(K, 0); mr = m0.repeat_interleave(K, 0); B = n * K + zH = inner.H_init.unsqueeze(0).expand(B, sf, hid).clone().to(inner.forward_dtype) + zL = inner.L_init.unsqueeze(0).expand(B, sf, hid).clone().to(inner.forward_dtype) + if sg > 0: + zH = (zH.float() + sg*sc*torch.randn(zH.shape, generator=g, device=dev)).to(inner.forward_dtype) + zL = (zL.float() + sg*sc*torch.randn(zL.shape, generator=g, device=dev)).to(inner.forward_dtype) + solved = torch.zeros(B, dtype=torch.bool, device=dev) + for s in range(n_seg): + for _h in range(inner.config.H_cycles): + for _l in range(inner.config.L_cycles): zL = inner.L_level(zL, zH + emb, **si) + zH = Hup(zH, zL, **si) + ok = ((inner.lm_head(zH)[:, pe:].float().argmax(-1) == Yr) | ~mr).all(-1); solved |= ok + if sj == 0 and s == readout - 1: base = ok.view(n, K)[:, 0].cpu().numpy() + solve[:, sj] = solved.view(n, K).cpu().numpy() + return solve, base, np.array(sigmas) + +solve, base, sg = perturb_z0(inp, lab, pid) +fail = ~base; nf = int(fail.sum()) +print(f"{nf} of {len(base)} puzzles fail@{16}; freeing them by restarting from a perturbed IC:") +for j, s in enumerate(sg): + sub = solve[fail, j]; print(f" sigma={s:.1f}: single-restart={sub.mean():.2f} best-of-K={sub.any(1).mean():.2f}") +plt.figure(figsize=(6, 4)) +plt.plot(sg, [solve[fail, j].mean() for j in range(len(sg))], 'o--', label='single restart') +plt.plot(sg, [solve[fail, j].any(1).mean() for j in range(len(sg))], 's-', label='best-of-K') +plt.xlabel('relative IC noise sigma'); plt.ylabel('solve rate (failing puzzles)') +plt.title('basin accessibility: does a restart free a trapped puzzle?'); plt.legend(); plt.grid(alpha=.3); plt.show()""") md("""## What this shows -- **TRM**: accuracy keeps climbing with compute; step-16 failures *escape* the chaotic transient and - resolve to the correct answer (≈0 settle to a wrong answer). → a chaotic **saddle** + one solution - fixed point. *More inference compute solves more puzzles.* -- **HRM**: accuracy plateaus; failures stay **trapped** (latent keeps churning, never escapes). → - bistability between a stable fixed point (success) and a chaotic **attractor** (failure). -- Neither settles to a *wrong fixed point* — the "spurious fixed point" reading from 2D PCA is an - artifact of projecting high-dimensional chaotic wandering. +- **TRM**: step-16 failures *escape* the chaotic transient and resolve to the correct answer + (≈0 settle to a wrong answer) → a chaotic **saddle** + one solution fixed point. More compute + solves more puzzles. +- **HRM**: failures escape too, but *much* more slowly — most are still churning at this horizon. + Out to 4000 segments the never-correct fraction keeps decaying (≈0.87→0.77), so it is a + **strongly-trapping chaotic saddle**, NOT a strict attractor. And the per-segment escape-rate gap + (~5×) is mostly compute-per-segment: TRM evaluates its recurrent module 21×/segment vs HRM 6×, so + per module-evaluation the gap is only ~1.6×. +- **Neither settles to a wrong fixed point** — the "spurious fixed point" reading from 2D PCA is an + artifact of projecting high-dimensional chaotic wandering onto two axes. Try: change `MODEL`, `N_SEG`, `eps` (toy); compare TRM vs HRM escape curves.""") |
