summaryrefslogtreecommitdiff
path: root/extend_rollout.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-06-29 12:15:51 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-06-29 12:15:51 -0500
commita6ec4288a2232988b130b2f00bb2565f81706966 (patch)
tree1bb86e7f0b899b823b9e7fdf383e832d30a181e0 /extend_rollout.py
Recursive reasoning dynamics: analysis pipeline, paper drafts, toy models
Failure=more-chaotic (task-general under validity labeling) reduces to convergence/completeness detection; mechanism (transient chaos vs multistability vs input-induced) under investigation. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'extend_rollout.py')
-rw-r--r--extend_rollout.py68
1 files changed, 68 insertions, 0 deletions
diff --git a/extend_rollout.py b/extend_rollout.py
new file mode 100644
index 0000000..09af1f8
--- /dev/null
+++ b/extend_rollout.py
@@ -0,0 +1,68 @@
+"""Discriminate Rainer's hypotheses: run the trained recurrence FAR beyond the 16-segment budget
+and watch the fate of trajectories that FAIL at segment 16.
+ - settle to CORRECT later => transient that would self-resolve (more compute helps)
+ - settle to WRONG (drift->0) => multistable WRONG attractor (genuine bistability)
+ - never settle (drift stays high) => chaotic saddle / persistent non-convergence
+Plain forward (no JVP). Saves per-segment decoded-exactness and per-segment z_H drift.
+"""
+from __future__ import annotations
+import sys, argparse
+from pathlib import Path
+import numpy as np
+import torch
+
+sys.path.insert(0, "/home/yurenh2/rrm/research/flossing")
+from diagnose_trm_joint_maze import load_model, load_test_samples
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--ckpt-root", required=True)
+ ap.add_argument("--ckpt-name", required=True)
+ ap.add_argument("--data", required=True)
+ ap.add_argument("--n", type=int, default=512)
+ ap.add_argument("--batch-size", type=int, default=32)
+ ap.add_argument("--n-seg", type=int, default=128)
+ ap.add_argument("--seed", type=int, default=0)
+ ap.add_argument("--out", required=True)
+ args = ap.parse_args()
+ device = "cuda"
+ model, cfg, train_meta = load_model(Path(args.ckpt_root), args.ckpt_name, device)
+ inner = model.inner
+ test = load_test_samples(Path(args.data), args.n, 0, 1, args.seed)
+ n = len(test["inputs"]); pe = inner.puzzle_emb_len
+
+ EX, DR, IDX = [], [], []
+ for s in range(0, n, args.batch_size):
+ e = min(s + args.batch_size, n)
+ batch = {k: test[k][s:e].to(device) for k in ["inputs", "labels", "puzzle_identifiers"]}
+ B = batch["inputs"].shape[0]
+ seq_full = inner.config.seq_len + pe; hidden = inner.config.hidden_size
+ with torch.no_grad():
+ z_H = inner.H_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype)
+ z_L = inner.L_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype)
+ seq_info = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None)
+ inp_emb = inner._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
+ labels = batch["labels"]; mask = labels > 0
+ prev_zH = None; ex_seg, dr_seg = [], []
+ for seg in range(args.n_seg):
+ for _h in range(inner.config.H_cycles):
+ for _l in range(inner.config.L_cycles):
+ z_L = inner.L_level(z_L, z_H + inp_emb, **seq_info)
+ z_H = inner.L_level(z_H, z_L, **seq_info)
+ p = inner.lm_head(z_H)[:, pe:].float().argmax(-1)
+ ex_seg.append(((p == labels) | ~mask).all(-1).float().cpu())
+ dr_seg.append((torch.zeros(B) if prev_zH is None
+ else (z_H - prev_zH).float().flatten(1).norm(dim=1).cpu()))
+ prev_zH = z_H.detach()
+ EX.append(torch.stack(ex_seg, 1).numpy()); DR.append(torch.stack(dr_seg, 1).numpy())
+ IDX.append(test["idx"][s:e])
+ print(f" [{e}/{n}] exact@16={torch.stack(ex_seg,1)[:,15].mean():.3f} exact@{args.n_seg}={torch.stack(ex_seg,1)[:,-1].mean():.3f}", flush=True)
+
+ np.savez_compressed(args.out, exact_seg=np.concatenate(EX), drift_seg=np.concatenate(DR),
+ idx=np.concatenate(IDX))
+ print("saved", args.out)
+
+
+if __name__ == "__main__":
+ main()