summaryrefslogtreecommitdiff
path: root/maze_pred_dump.py
blob: ebbc735139aaa8eb78a0ea79dbdaba749c310aff (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
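"""Dump per-cell predictions for a TRM-Maze checkpoint (plain forward pass, no JVP) so we can
analyze WHERE the errors in failed puzzles are: a connected detour (a coherent, stable wrong path)
versus scattered wrong cells.
Saves preds, labels, inputs, exact_correct and idx for n test puzzles. It also saves per-step drift
traces (ans_drift_full, ans_drift_ans, drift_zH), which record how the decoded answer and the latent
z_H change across the outer (halt_max_steps) iterations.
"""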
from __future__ import annotations
import sys, argparse
from pathlib import Path
import numpy as np
import torch

sys.path.insert(0, "/home/yurenh2/rrm/research/flossing")
from diagnose_trm_joint_maze import load_model, load_test_samples  # att+maze-capable loader


@torch.no_grad()
def _rollout_batch(inner, batch: dict, device: str):
    """Run the plain (no-grad, no-JVP) TRM recursion on one batch, decoding after every outer step.

    For each of ``inner.config.halt_max_steps`` outer steps the loop runs ``H_cycles`` rounds.
    Each round does ``L_cycles`` z_L updates, then one z_H update, all through ``inner.L_level``.
    The answer is decoded by argmax over ``inner.lm_head(z_H)`` with the puzzle-embedding prefix
    dropped.

    Args:
        inner: the model's inner TRM module, as exposed by ``model.inner``.
        batch: dict with "inputs", "labels" and "puzzle_identifiers" tensors already on ``device``.
        device: device used to allocate the step-0 zero drift entries.

    Returns:
        preds: argmax predictions after the final outer step.
        exact: per-puzzle bool, True when every cell with ``labels > 0`` is predicted correctly.
            Cells with label <= 0 are ignored (presumably padding/ignore ids; confirm for maze data).
        drift_full: (B, steps) count of cells whose decoded answer changed vs the previous step.
        drift_ans: (B, steps) the same count, restricted to cells where ``labels != inputs``.
        drift_zH: (B, steps) L2 norm of the change in z_H vs the previous step.
        Entries at step 0 of all three drift traces are zero placeholders (no previous step).

    Raises:
        ValueError: if ``halt_max_steps < 1``; there would be no decoded answer to return.
    """
    n_steps = inner.config.halt_max_steps
    if n_steps < 1:
        raise ValueError(f"halt_max_steps must be >= 1 to decode predictions, got {n_steps}")

    inputs, labels = batch["inputs"], batch["labels"]
    B = inputs.shape[0]
    pe = inner.puzzle_emb_len
    seq_full = inner.config.seq_len + pe
    hidden = inner.config.hidden_size

    z_H = inner.H_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype)
    z_L = inner.L_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype)
    seq_info = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None)
    inp_emb = inner._input_embeddings(inputs, batch["puzzle_identifiers"])
    # SOLUTION-SPACE cells: positions where the task requires changing the input.
    ans_mask = labels != inputs

    drift_full, drift_ans, drift_zH = [], [], []
    prev = prev_zH = None
    for _ in range(n_steps):
        for _h in range(inner.config.H_cycles):
            for _l in range(inner.config.L_cycles):
                z_L = inner.L_level(z_L, z_H + inp_emb, **seq_info)
            # The z_H update also goes through L_level (single shared network).
            z_H = inner.L_level(z_H, z_L, **seq_info)
        p = inner.lm_head(z_H)[:, pe:].float().argmax(-1)  # decode the answer at THIS step
        if prev is None:
            zeros = torch.zeros(B, device=device)
            drift_full.append(zeros)
            drift_ans.append(zeros)
            drift_zH.append(zeros)
        else:
            changed = p != prev
            drift_full.append(changed.float().sum(-1))
            drift_ans.append((changed & ans_mask).float().sum(-1))
            drift_zH.append((z_H - prev_zH).float().flatten(1).norm(dim=1))
        prev, prev_zH = p, z_H.detach()

    preds = prev
    mask = labels > 0
    exact = ((preds == labels) | ~mask).all(-1)
    return (
        preds,
        exact,
        torch.stack(drift_full, 1),
        torch.stack(drift_ans, 1),
        torch.stack(drift_zH, 1),
    )


def main():
    """Parse CLI args, roll out the checkpoint on n test puzzles, and save a compressed .npz dump.

    Saved arrays (first axis = number of loaded puzzles):
      preds, labels, inputs   per-cell tokens (preds come from the final outer step)
      exact_correct           float32 1.0/0.0 per puzzle
      idx                     the "idx" entries returned by load_test_samples
      ans_drift_full          (N, steps) decoded-answer Hamming drift over all cells
      ans_drift_ans           (N, steps) decoded-answer Hamming drift over solution-space cells only
      drift_zH                (N, steps) latent z_H L2 drift
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt-root", required=True)
    ap.add_argument("--ckpt-name", default="step_130200")
    ap.add_argument("--data", required=True)
    ap.add_argument("--n", type=int, default=512)
    ap.add_argument("--batch-size", type=int, default=32)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--out", required=True)
    # Previously hard-coded; the default keeps the old behavior.
    ap.add_argument("--device", default="cuda")
    args = ap.parse_args()
    if args.batch_size < 1:
        ap.error("--batch-size must be >= 1")
    device = args.device

    model, _cfg, _train_meta = load_model(Path(args.ckpt_root), args.ckpt_name, device)
    # NOTE(review): the model is used as returned by load_model (no explicit .eval()); confirm it
    # is in eval mode if any submodule behaves differently in training.
    inner = model.inner
    test = load_test_samples(Path(args.data), args.n, 0, 1, args.seed)
    n = len(test["inputs"])
    if n == 0:
        # np.concatenate below would otherwise fail with an unhelpful error on empty lists.
        sys.exit(f"no test samples loaded from {args.data}")

    preds_all, labels_all, inputs_all, exact_all, idx_all = [], [], [], [], []
    drift_full_all, drift_ans_all, drift_zH_all = [], [], []
    for s in range(0, n, args.batch_size):
        e = min(s + args.batch_size, n)
        batch = {k: test[k][s:e].to(device) for k in ("inputs", "labels", "puzzle_identifiers")}
        preds, exact, d_full, d_ans, d_zH = _rollout_batch(inner, batch, device)

        preds_all.append(preds.cpu().numpy())
        labels_all.append(batch["labels"].cpu().numpy())
        inputs_all.append(batch["inputs"].cpu().numpy())
        exact_all.append(exact.cpu().numpy())
        idx_all.append(test["idx"][s:e])
        drift_full_all.append(d_full.cpu().numpy())
        drift_ans_all.append(d_ans.cpu().numpy())
        drift_zH_all.append(d_zH.cpu().numpy())
        print(f"  [{e}/{n}] exact={exact.float().mean():.3f}", flush=True)

    np.savez_compressed(
        args.out,
        preds=np.concatenate(preds_all),
        labels=np.concatenate(labels_all),
        inputs=np.concatenate(inputs_all),
        exact_correct=np.concatenate(exact_all).astype(np.float32),
        idx=np.concatenate(idx_all),
        # (N, steps) decoded-answer Hamming drift vs previous step, over all cells
        ans_drift_full=np.concatenate(drift_full_all),
        # (N, steps) same drift, over solution-space cells only (labels != inputs)
        ans_drift_ans=np.concatenate(drift_ans_all),
        # (N, steps) L2 norm of the latent z_H change vs previous step
        drift_zH=np.concatenate(drift_zH_all),
    )
    print("saved", args.out)


if __name__ == "__main__":
    # Script entry point only (not on import). main() returns None, so the exit status stays 0.
    sys.exit(main())