summaryrefslogtreecommitdiff
path: root/notebooks/build_notebook.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-06-30 13:07:37 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-06-30 13:07:37 -0500
commit94bd896fb7c7bb3d6441ebec8887220532dbe690 (patch)
tree747cd0c0faea7a27c9233435c7ef02f24e416dc2 /notebooks/build_notebook.py
parentc6336c35d77974529de3aca966a366d52cacb8a4 (diff)
notebook: fix rollout drift-norm crash (.norm(1)->.norm(dim=1)); CPU warning; lighter FTLE defaultHEADmain
- extended_rollout: per-row drift norm was a scalar -> np.stack shape mismatch crash. fixed + verified. - setup: warn + Colab GPU instructions when no CUDA (CPU is very slow for the JVP cell). - FTLE default n=128->64 (still AUC~0.99) for faster runs. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01MZBxRQ65wDxiUSm9Hr5Ere
Diffstat (limited to 'notebooks/build_notebook.py')
-rw-r--r--notebooks/build_notebook.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/notebooks/build_notebook.py b/notebooks/build_notebook.py
index 8a525b2..930dc68 100644
--- a/notebooks/build_notebook.py
+++ b/notebooks/build_notebook.py
@@ -35,7 +35,10 @@ md("## 0. Setup")
code("""%pip install -q torch einops pydantic huggingface_hub numpy matplotlib tqdm
import numpy as np, matplotlib.pyplot as plt, torch
from tqdm.auto import tqdm
-print("torch", torch.__version__, "| cuda", torch.cuda.is_available())""")
+print("torch", torch.__version__, "| cuda", torch.cuda.is_available())
+if not torch.cuda.is_available():
+ print("\\n⚠️ No GPU detected — the JVP/rollout cells will be SLOW on CPU.")
+ print(" Colab: Runtime → Change runtime type → Hardware accelerator → GPU (T4), then re-run.")""")
md(f"""## 1. Load a trained model from HuggingFace
@@ -126,7 +129,7 @@ def leading_ftle(inp, lab, pid, n=128, n_seg=16, seed=0):
ok=(((inner.lm_head(zH)[:,pe:].float().argmax(-1)==Y)|~m).all(-1)).cpu().numpy()
return ftle, ok
-ftle, succ = leading_ftle(inp, lab, pid, n=128)
+ftle, succ = leading_ftle(inp, lab, pid, n=64) # n=64 already separates cleanly; raise for tighter histograms
print(f"success rate {succ.mean():.2f} | median λ1 success {np.median(ftle[succ]):+.4f} vs failure {np.median(ftle[~succ]):+.4f}")
print(f"AUC(-λ1 -> success) = {auc(-ftle, succ.astype(int)):.3f} (>0.5 means failures are more chaotic)")
plt.figure(figsize=(6,4))
@@ -163,7 +166,7 @@ code("""def extended_rollout(inp, lab, pid, n=256, n_seg=128, seed=0):
zH=Hmod(zH, zL, **si)
p=inner.lm_head(zH)[:,pe:].float().argmax(-1)
EX.append(((p==Y)|~m).all(-1).float().cpu().numpy())
- DR.append((torch.zeros(n) if prev is None else (zH-prev).float().flatten(1).norm(1).cpu()).numpy())
+ DR.append((torch.zeros(n) if prev is None else (zH-prev).float().flatten(1).norm(dim=1).cpu()).numpy())
prev=zH.detach()
return np.stack(EX,1), np.stack(DR,1)