summaryrefslogtreecommitdiff
path: root/notebooks/build_notebook.py
diff options
context:
space:
mode:
Diffstat (limited to 'notebooks/build_notebook.py')
-rw-r--r--notebooks/build_notebook.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/notebooks/build_notebook.py b/notebooks/build_notebook.py
index 8a525b2..930dc68 100644
--- a/notebooks/build_notebook.py
+++ b/notebooks/build_notebook.py
@@ -35,7 +35,10 @@ md("## 0. Setup")
code("""%pip install -q torch einops pydantic huggingface_hub numpy matplotlib tqdm
import numpy as np, matplotlib.pyplot as plt, torch
from tqdm.auto import tqdm
-print("torch", torch.__version__, "| cuda", torch.cuda.is_available())""")
+print("torch", torch.__version__, "| cuda", torch.cuda.is_available())
+if not torch.cuda.is_available():
+ print("\\n⚠️ No GPU detected — the JVP/rollout cells will be SLOW on CPU.")
+ print(" Colab: Runtime → Change runtime type → Hardware accelerator → GPU (T4), then re-run.")""")
md(f"""## 1. Load a trained model from HuggingFace
@@ -126,7 +129,7 @@ def leading_ftle(inp, lab, pid, n=128, n_seg=16, seed=0):
ok=(((inner.lm_head(zH)[:,pe:].float().argmax(-1)==Y)|~m).all(-1)).cpu().numpy()
return ftle, ok
-ftle, succ = leading_ftle(inp, lab, pid, n=128)
+ftle, succ = leading_ftle(inp, lab, pid, n=64) # n=64 already separates cleanly; raise for tighter histograms
print(f"success rate {succ.mean():.2f} | median λ1 success {np.median(ftle[succ]):+.4f} vs failure {np.median(ftle[~succ]):+.4f}")
print(f"AUC(-λ1 -> success) = {auc(-ftle, succ.astype(int)):.3f} (>0.5 means failures are more chaotic)")
plt.figure(figsize=(6,4))
@@ -163,7 +166,7 @@ code("""def extended_rollout(inp, lab, pid, n=256, n_seg=128, seed=0):
zH=Hmod(zH, zL, **si)
p=inner.lm_head(zH)[:,pe:].float().argmax(-1)
EX.append(((p==Y)|~m).all(-1).float().cpu().numpy())
- DR.append((torch.zeros(n) if prev is None else (zH-prev).float().flatten(1).norm(1).cpu()).numpy())
+ DR.append((torch.zeros(n) if prev is None else (zH-prev).float().flatten(1).norm(dim=1).cpu()).numpy())
prev=zH.detach()
return np.stack(EX,1), np.stack(DR,1)