diff options
Diffstat (limited to 'research/flossing/analyze_tangent_modes.py')
| -rw-r--r-- | research/flossing/analyze_tangent_modes.py | 143 |
1 files changed, 143 insertions, 0 deletions
diff --git a/research/flossing/analyze_tangent_modes.py b/research/flossing/analyze_tangent_modes.py new file mode 100644 index 0000000..5801ad8 --- /dev/null +++ b/research/flossing/analyze_tangent_modes.py @@ -0,0 +1,143 @@ +"""(a) lite: For each test sample, save the final tangent basis Q (top-k modes after +running through the full inference). Compute position/hidden activity profiles per +mode and compare success vs failure groups. +""" +from __future__ import annotations +import sys, os, yaml, json, time, argparse +from pathlib import Path +import numpy as np +import torch + +HRM_DIR = Path("/home/yurenh2/rrm/hrm") +sys.path.insert(0, str(HRM_DIR)) +from models.hrm.hrm_act_v1 import HierarchicalReasoningModel_ACTV1 + + +def load_model(ckpt_root, ckpt_name, device): + cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text()) + arch_cfg = dict(cfg["arch"]) + train_meta = json.loads((Path(cfg["data_path"]) / "train" / "dataset.json").read_text()) + arch_cfg.update(batch_size=cfg["global_batch_size"], seq_len=train_meta["seq_len"], + vocab_size=train_meta["vocab_size"], + num_puzzle_identifiers=train_meta["num_puzzle_identifiers"], causal=False) + model = HierarchicalReasoningModel_ACTV1(arch_cfg) + sd = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True) + stripped = {} + for k, v in sd.items(): + nk = k + for p in ("_orig_mod.", "model."): + if nk.startswith(p): nk = nk[len(p):] + stripped[nk] = v + model.load_state_dict(stripped, strict=False) + model.to(device).eval() + return model, cfg, train_meta + + +def jvp_one(f, x, v): + return torch.autograd.functional.jvp(f, x, v=v, create_graph=False, strict=False) + + +def run_save_final_Q(model, batch, k_lyap, device, seed): + """Run inference with QR-iteration on top-k tangents; return final Q (B, seq, hidden, k) + after all ACT steps. Also return exact_correct, predicted_logits. + """ + inner = model.inner + cfg = inner.config + B = batch["inputs"].shape[0] + seq_full = cfg.seq_len + inner.puzzle_emb_len + hidden = cfg.hidden_size + state_dim = seq_full * hidden + + z_H = inner.H_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) + z_L = inner.L_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) + seq_info = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None) + input_embeddings = inner._input_embeddings(batch["inputs"].to(device), + batch["puzzle_identifiers"].to(device)) + + g = torch.Generator(device=device).manual_seed(seed) + Q0 = torch.randn(B, state_dim, k_lyap, device=device, dtype=torch.float32, generator=g) + Q, _ = torch.linalg.qr(Q0) + + with torch.enable_grad(): + for _act in range(cfg.halt_max_steps): + zH = z_H.detach(); zL = z_L.detach() + for _h in range(cfg.H_cycles): + for _l in range(cfg.L_cycles): + out = [] + fx_last = None + f = lambda x: inner.L_level(x, zH + input_embeddings, **seq_info) + for i in range(k_lyap): + v_i = Q[..., i].view_as(zL) + fx, Dv = jvp_one(f, zL, v_i) + out.append(Dv.reshape(B, state_dim).to(torch.float32)) + fx_last = fx + Q = torch.stack(out, dim=-1) + zL = fx_last + Q, R = torch.linalg.qr(Q) + out = [] + f = lambda x: inner.H_level(x, zL, **seq_info) + for i in range(k_lyap): + v_i = Q[..., i].view_as(zH) + fx, Dv = jvp_one(f, zH, v_i) + out.append(Dv.reshape(B, state_dim).to(torch.float32)) + fx_last = fx + Q = torch.stack(out, dim=-1) + zH = fx_last + Q, R = torch.linalg.qr(Q) + z_H, z_L = zH, zL + + with torch.no_grad(): + output = inner.lm_head(z_H)[:, inner.puzzle_emb_len:].float() + preds = output.argmax(dim=-1) + labels = batch["labels"].to(device) + mask = labels > 0 + exact = ((preds == labels) | ~mask).all(dim=-1).cpu().float().numpy() + + Q_final = Q.reshape(B, seq_full, hidden, k_lyap).cpu().float().numpy() + return Q_final, exact + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--ckpt-root", required=True) + ap.add_argument("--ckpt-name", default="step_26040") + ap.add_argument("--n-samples", type=int, default=256) + ap.add_argument("--batch-size", type=int, default=32) + ap.add_argument("--k-lyap", type=int, default=4) + ap.add_argument("--seed", type=int, default=0) + ap.add_argument("--out", default="tangent_modes.npz") + args = ap.parse_args() + device = "cuda" + + model, cfg, train_meta = load_model(Path(args.ckpt_root), args.ckpt_name, device) + + rng = np.random.default_rng(args.seed) + data_path = Path(cfg["data_path"]) + inputs = np.load(data_path / "test" / "all__inputs.npy") + labels = np.load(data_path / "test" / "all__labels.npy") + pid = np.load(data_path / "test" / "all__puzzle_identifiers.npy") + idx = rng.choice(len(inputs), size=args.n_samples, replace=False) + + Q_all = []; exact_all = [] + t0 = time.time() + for s in range(0, args.n_samples, args.batch_size): + e = min(s + args.batch_size, args.n_samples) + bidx = idx[s:e] + batch = { + "inputs": torch.from_numpy(inputs[bidx].astype(np.int32)), + "labels": torch.from_numpy(labels[bidx].astype(np.int32)), + "puzzle_identifiers": torch.from_numpy(pid[bidx].astype(np.int32)), + } + Q_final, exact = run_save_final_Q(model, batch, args.k_lyap, device, seed=args.seed + s) + Q_all.append(Q_final); exact_all.append(exact) + print(f" [{e}/{args.n_samples}] dt={time.time()-t0:.1f}s exact={exact.mean():.3f}", flush=True) + + Q_all = np.concatenate(Q_all, axis=0) # (N, seq, hidden, k) + exact_all = np.concatenate(exact_all, axis=0) # (N,) + print(f"saved shape Q={Q_all.shape}, exact={exact_all.shape}, acc={exact_all.mean():.3f}") + np.savez_compressed(args.out, Q_final=Q_all, exact_correct=exact_all, sample_idx=idx) + print(f"saved → {args.out}") + + +if __name__ == "__main__": + main() |
