diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-29 12:15:51 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-29 12:15:51 -0500 |
| commit | a6ec4288a2232988b130b2f00bb2565f81706966 (patch) | |
| tree | 1bb86e7f0b899b823b9e7fdf383e832d30a181e0 /maze_pred_dump.py | |
Recursive reasoning dynamics: analysis pipeline, paper drafts, toy models
The "failure = more chaotic" finding (task-general under validity labeling) reduces to convergence/completeness
detection; the mechanism (transient chaos vs. multistability vs. input-induced) is still under investigation.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'maze_pred_dump.py')
| -rw-r--r-- | maze_pred_dump.py | 85 |
1 files changed, 85 insertions, 0 deletions
diff --git a/maze_pred_dump.py b/maze_pred_dump.py new file mode 100644 index 0000000..ebbc735 --- /dev/null +++ b/maze_pred_dump.py @@ -0,0 +1,85 @@ +"""Dump per-cell predictions for a TRM-Maze checkpoint (plain forward, no JVP) so we can +analyze WHERE failure errors are (connected detour = coherent stable wrong path vs scattered). +Saves preds, labels, inputs, exact_correct, idx for n test puzzles. +""" +from __future__ import annotations +import sys, argparse +from pathlib import Path +import numpy as np +import torch + +sys.path.insert(0, "/home/yurenh2/rrm/research/flossing") +from diagnose_trm_joint_maze import load_model, load_test_samples # att+maze-capable loader + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--ckpt-root", required=True) + ap.add_argument("--ckpt-name", default="step_130200") + ap.add_argument("--data", required=True) + ap.add_argument("--n", type=int, default=512) + ap.add_argument("--batch-size", type=int, default=32) + ap.add_argument("--seed", type=int, default=0) + ap.add_argument("--out", required=True) + args = ap.parse_args() + device = "cuda" + model, cfg, train_meta = load_model(Path(args.ckpt_root), args.ckpt_name, device) + inner = model.inner + test = load_test_samples(Path(args.data), args.n, 0, 1, args.seed) + n = len(test["inputs"]) + pe = inner.puzzle_emb_len + + preds_all, labels_all, inputs_all, exact_all, idx_all = [], [], [], [], [] + ans_drift_full_all, ans_drift_ans_all, ldrift_all = [], [], [] + for s in range(0, n, args.batch_size): + e = min(s + args.batch_size, n) + batch = {k: test[k][s:e].to(device) for k in ["inputs", "labels", "puzzle_identifiers"]} + B = batch["inputs"].shape[0] + seq_full = inner.config.seq_len + pe + hidden = inner.config.hidden_size + with torch.no_grad(): + z_H = inner.H_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) + z_L = inner.L_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) + seq_info = 
dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None) + inp_emb = inner._input_embeddings(batch["inputs"], batch["puzzle_identifiers"]) + labels = batch["labels"] + ans_mask = (labels != batch["inputs"]) # SOLUTION-SPACE cells: where task requires change + prev = None; prev_zH = None + adrift_full, adrift_ans, ldrift = [], [], [] # answer Hamming drift + LATENT z_H drift + for _ in range(inner.config.halt_max_steps): + for _h in range(inner.config.H_cycles): + for _l in range(inner.config.L_cycles): + z_L = inner.L_level(z_L, z_H + inp_emb, **seq_info) + z_H = inner.L_level(z_H, z_L, **seq_info) + p = inner.lm_head(z_H)[:, pe:].float().argmax(-1) # decode answer THIS step + if prev is None: + adrift_full.append(torch.zeros(B, device=device)); adrift_ans.append(torch.zeros(B, device=device)) + ldrift.append(torch.zeros(B, device=device)) + else: + adrift_full.append((p != prev).float().sum(-1)) + adrift_ans.append(((p != prev) & ans_mask).float().sum(-1)) + ldrift.append((z_H - prev_zH).float().flatten(1).norm(dim=1)) # latent z_H drift + prev = p; prev_zH = z_H.detach() + preds = prev + mask = labels > 0 + exact = ((preds == labels) | ~mask).all(-1) + preds_all.append(preds.cpu().numpy()); labels_all.append(labels.cpu().numpy()) + inputs_all.append(batch["inputs"].cpu().numpy()); exact_all.append(exact.cpu().numpy()) + idx_all.append(test["idx"][s:e]) + ans_drift_full_all.append(torch.stack(adrift_full, 1).cpu().numpy()) + ans_drift_ans_all.append(torch.stack(adrift_ans, 1).cpu().numpy()) + ldrift_all.append(torch.stack(ldrift, 1).cpu().numpy()) + print(f" [{e}/{n}] exact={exact.float().mean():.3f}", flush=True) + + np.savez_compressed(args.out, + preds=np.concatenate(preds_all), labels=np.concatenate(labels_all), + inputs=np.concatenate(inputs_all), exact_correct=np.concatenate(exact_all).astype(np.float32), + idx=np.concatenate(idx_all), + ans_drift_full=np.concatenate(ans_drift_full_all), # (N, steps) decoded-answer Hamming drift + 
ans_drift_ans=np.concatenate(ans_drift_ans_all), + drift_zH=np.concatenate(ldrift_all)) # (N, steps) over solution-space cells only + print("saved", args.out) + + +if __name__ == "__main__": + main() |
