diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
| commit | 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch) | |
| tree | c29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/ptrm_rollout_selection.py | |
Curated export for clone-and-run Maze training (2x A6000) + diagnostics.
trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible).
Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/ptrm_rollout_selection.py')
| -rw-r--r-- | research/flossing/ptrm_rollout_selection.py | 643 |
1 files changed, 643 insertions, 0 deletions
diff --git a/research/flossing/ptrm_rollout_selection.py b/research/flossing/ptrm_rollout_selection.py new file mode 100644 index 0000000..4fb1dd9 --- /dev/null +++ b/research/flossing/ptrm_rollout_selection.py @@ -0,0 +1,643 @@ +"""PTRM-style stochastic rollout evaluation with Q and stability selection. + +This is an inference-time experiment: no training, no weight updates. + +For each input, run K stochastic recursive trajectories by injecting Gaussian +noise into the latent state before every ACT step. Select a trajectory by: + - Q head score (PTRM) + - finite-difference top Lyapunov proxy (lowest lambda) + - finite-difference low-rank Lyapunov spectrum proxies + - simple Q/lambda hybrid scores + +The Lyapunov proxy is computed by pairing each rollout with a tiny shadow +trajectory that receives the same stochastic noise and is renormalized after +each ACT step. This is much cheaper than JVP-based exact spectrum estimation +and is enough to test whether stability can act as a free selector. + +The optional spectrum proxy generalizes the shadow trajectory to k orthogonal +shadows and uses QR re-orthogonalization after every ACT step. This estimates +the top-k finite-time spectrum in a random tangent subspace. It is much more +expensive than top-1 because the model batch is multiplied by k + 1. +""" +from __future__ import annotations + +import argparse +import csv +import json +import math +import sys +from dataclasses import replace +from pathlib import Path +from typing import Any + +import numpy as np +import torch + +TRM_DIR = Path("/home/yurenh2/rrm/trm") +sys.path.insert(0, str(TRM_DIR)) + +from models.recursive_reasoning.trm import ( # noqa: E402 + TinyRecursiveReasoningModel_ACTV1, + TinyRecursiveReasoningModel_ACTV1InnerCarry, +) + + +IGNORE_LABEL_ID = -100 + + +def load_model(ckpt_root: Path, ckpt_name: str, device: str): + cfg = json.loads(json.dumps(__import__("yaml").safe_load((ckpt_root / "all_config.yaml").read_text()))) + train_meta = json.loads((Path(cfg["data_paths"][0]) / "train" / "dataset.json").read_text()) + + arch_cfg = dict(cfg["arch"]) + arch_cfg.update( + batch_size=cfg["global_batch_size"], + seq_len=train_meta["seq_len"], + vocab_size=train_meta["vocab_size"], + num_puzzle_identifiers=train_meta["num_puzzle_identifiers"], + ) + + model = TinyRecursiveReasoningModel_ACTV1(arch_cfg) + state = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True) + stripped = {k.replace("_orig_mod.", "").replace("model.", ""): v for k, v in state.items()} + missing, unexpected = model.load_state_dict(stripped, strict=False) + print(f"[load] {ckpt_root.name}/{ckpt_name} missing={len(missing)} unexpected={len(unexpected)}") + if missing[:3]: + print(f"[load] sample missing: {missing[:3]}") + if unexpected[:3]: + print(f"[load] sample unexpected: {unexpected[:3]}") + model.to(device).eval() + return model, cfg + + +def load_test_samples(data_path: Path, n_samples: int, seed: int): + rng = np.random.default_rng(seed) + inputs = np.load(data_path / "test" / "all__inputs.npy") + labels = np.load(data_path / "test" / "all__labels.npy") + puzzle_ids = np.load(data_path / "test" / "all__puzzle_identifiers.npy") + + n = min(n_samples, len(inputs)) + idx = rng.choice(len(inputs), size=n, replace=False) + return { + "inputs": torch.from_numpy(inputs[idx].astype(np.int32)), + "labels": torch.from_numpy(labels[idx].astype(np.int32)), + "puzzle_identifiers": torch.from_numpy(puzzle_ids[idx].astype(np.int32)), + "idx": idx, + } + + +def batch_slice(samples: dict[str, Any], start: int, end: int, device: str): + return { + k: v[start:end].to(device, non_blocking=True) + for k, v in samples.items() + if k in ("inputs", "labels", "puzzle_identifiers") + } + + +def repeat_batch(batch: dict[str, torch.Tensor], repeats: int): + if repeats == 1: + return batch + return {k: v.repeat_interleave(repeats, dim=0) for k, v in batch.items()} + + +def cat_batches(a: dict[str, torch.Tensor], b: dict[str, torch.Tensor]): + return {k: torch.cat([a[k], b[k]], dim=0) for k in a} + + +def _rand_unit_like(inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, generator: torch.Generator): + dh = torch.randn(inner.z_H.shape, device=inner.z_H.device, dtype=torch.float32, generator=generator) + dl = torch.randn(inner.z_L.shape, device=inner.z_L.device, dtype=torch.float32, generator=generator) + norm = torch.sqrt(dh.flatten(1).square().sum(-1) + dl.flatten(1).square().sum(-1)).clamp_min(1e-30) + view_h = (dh.shape[0],) + (1,) * (dh.ndim - 1) + view_l = (dl.shape[0],) + (1,) * (dl.ndim - 1) + return (dh / norm.view(view_h)).to(inner.z_H.dtype), (dl / norm.view(view_l)).to(inner.z_L.dtype) + + +def _q_to_dirs( + q: torch.Tensor, + z_h_shape: torch.Size, + z_l_shape: torch.Size, + h_dtype: torch.dtype, + l_dtype: torch.dtype, +): + total, _dim, spec_k = q.shape + h_numel = math.prod(z_h_shape) + q_t = q.transpose(1, 2).contiguous() + h_dirs = q_t[:, :, :h_numel].reshape((total, spec_k) + tuple(z_h_shape)).to(h_dtype) + l_dirs = q_t[:, :, h_numel:].reshape((total, spec_k) + tuple(z_l_shape)).to(l_dtype) + return h_dirs, l_dirs + + +def _dirs_to_q(h_dirs: torch.Tensor, l_dirs: torch.Tensor): + q_t = torch.cat([h_dirs.float().flatten(2), l_dirs.float().flatten(2)], dim=2) + return q_t.transpose(1, 2).contiguous() + + +def _rand_orthonormal_dirs_like( + inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, + spec_k: int, + generator: torch.Generator, +): + total = inner.z_H.shape[0] + h_dirs = torch.randn( + (total, spec_k) + tuple(inner.z_H.shape[1:]), + device=inner.z_H.device, + dtype=torch.float32, + generator=generator, + ) + l_dirs = torch.randn( + (total, spec_k) + tuple(inner.z_L.shape[1:]), + device=inner.z_L.device, + dtype=torch.float32, + generator=generator, + ) + q, _ = torch.linalg.qr(_dirs_to_q(h_dirs, l_dirs), mode="reduced") + return _q_to_dirs(q, inner.z_H.shape[1:], inner.z_L.shape[1:], inner.z_H.dtype, inner.z_L.dtype) + + +def _make_spectrum_shadows( + main: TinyRecursiveReasoningModel_ACTV1InnerCarry, + h_dirs: torch.Tensor, + l_dirs: torch.Tensor, + eps: float, +): + total, spec_k = h_dirs.shape[:2] + z_h = main.z_H[:, None] + eps * h_dirs.to(main.z_H.dtype) + z_l = main.z_L[:, None] + eps * l_dirs.to(main.z_L.dtype) + return TinyRecursiveReasoningModel_ACTV1InnerCarry( + z_H=z_h.reshape((total * spec_k,) + tuple(main.z_H.shape[1:])).detach(), + z_L=z_l.reshape((total * spec_k,) + tuple(main.z_L.shape[1:])).detach(), + ) + + +def _repeat_inner_batch(batch: dict[str, torch.Tensor], repeats: int): + return {k: v.repeat_interleave(repeats, dim=0) for k, v in batch.items()} + + +def _cat_many_batches(batches: list[dict[str, torch.Tensor]]): + return {k: torch.cat([b[k] for b in batches], dim=0) for k in batches[0]} + + +def _split_inner(inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, n: int): + return ( + TinyRecursiveReasoningModel_ACTV1InnerCarry( + z_H=inner.z_H[:n].contiguous(), + z_L=inner.z_L[:n].contiguous(), + ), + TinyRecursiveReasoningModel_ACTV1InnerCarry( + z_H=inner.z_H[n:].contiguous(), + z_L=inner.z_L[n:].contiguous(), + ), + ) + + +def _split_spectrum_inner(inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, n_main: int): + return ( + TinyRecursiveReasoningModel_ACTV1InnerCarry( + z_H=inner.z_H[:n_main].contiguous(), + z_L=inner.z_L[:n_main].contiguous(), + ), + TinyRecursiveReasoningModel_ACTV1InnerCarry( + z_H=inner.z_H[n_main:].contiguous(), + z_L=inner.z_L[n_main:].contiguous(), + ), + ) + + +def _cat_inner(a: TinyRecursiveReasoningModel_ACTV1InnerCarry, b: TinyRecursiveReasoningModel_ACTV1InnerCarry): + return TinyRecursiveReasoningModel_ACTV1InnerCarry( + z_H=torch.cat([a.z_H, b.z_H], dim=0), + z_L=torch.cat([a.z_L, b.z_L], dim=0), + ) + + +def _sample_noise( + shape: torch.Size, + std: float, + generator: torch.Generator, + dtype: torch.dtype, + device: torch.device, +): + if std <= 0: + return torch.zeros(shape, device=device, dtype=dtype) + return (std * torch.randn(shape, device=device, dtype=torch.float32, generator=generator)).to(dtype) + + +def _apply_step_noise( + inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, + noise_h: torch.Tensor, + noise_l: torch.Tensor, + perturb: str, +): + z_h, z_l = inner.z_H, inner.z_L + if perturb in ("h", "both"): + z_h = z_h + noise_h + if perturb in ("l", "both"): + z_l = z_l + noise_l + return replace(inner, z_H=z_h, z_L=z_l) + + +def _separation( + main: TinyRecursiveReasoningModel_ACTV1InnerCarry, + shadow: TinyRecursiveReasoningModel_ACTV1InnerCarry, +): + dh = (shadow.z_H.float() - main.z_H.float()).flatten(1) + dl = (shadow.z_L.float() - main.z_L.float()).flatten(1) + return torch.sqrt(dh.square().sum(-1) + dl.square().sum(-1)).clamp_min(1e-30) + + +def _renormalize_shadow( + main: TinyRecursiveReasoningModel_ACTV1InnerCarry, + shadow: TinyRecursiveReasoningModel_ACTV1InnerCarry, + eps: float, +): + sep = _separation(main, shadow) + view_h = (sep.shape[0],) + (1,) * (main.z_H.ndim - 1) + view_l = (sep.shape[0],) + (1,) * (main.z_L.ndim - 1) + scale_h = (eps / sep).view(view_h).to(main.z_H.dtype) + scale_l = (eps / sep).view(view_l).to(main.z_L.dtype) + return TinyRecursiveReasoningModel_ACTV1InnerCarry( + z_H=(main.z_H + (shadow.z_H - main.z_H) * scale_h).detach(), + z_L=(main.z_L + (shadow.z_L - main.z_L) * scale_l).detach(), + ) + + +@torch.inference_mode() +def deterministic_eval(model, batch: dict[str, torch.Tensor]): + with torch.device(batch["inputs"].device): + carry = model.initial_carry(batch) + logits = None + q_halt = None + steps = 0 + while True: + carry, outputs = model(carry=carry, batch=batch) + logits = outputs["logits"] + q_halt = outputs["q_halt_logits"] + steps += 1 + if bool(carry.halted.all()): + break + exact, token_acc = correctness(logits, batch["labels"]) + return exact, token_acc, q_halt, steps + + +def correctness(logits: torch.Tensor, labels: torch.Tensor): + preds = logits.argmax(dim=-1) + mask = labels != IGNORE_LABEL_ID + exact = torch.where(mask, preds == labels, True).all(dim=-1) + denom = mask.sum(-1).clamp_min(1) + token_acc = ((preds == labels) & mask).sum(-1).float() / denom.float() + return exact, token_acc + + +@torch.inference_mode() +def ptrm_rollouts( + model, + batch: dict[str, torch.Tensor], + rollouts: int, + steps: int, + noise_std: float, + include_clean: bool, + perturb: str, + fd_lyap: bool, + fd_spectrum_k: int, + fd_eps: float, + generator: torch.Generator, +): + device = batch["inputs"].device + base_batch_size = batch["inputs"].shape[0] + expanded = repeat_batch(batch, rollouts) + total = expanded["inputs"].shape[0] + rollout_id = torch.arange(total, device=device) % rollouts + + with torch.device(device): + carry = model.initial_carry(expanded) + reset = torch.ones_like(carry.halted) + main = model.inner.reset_carry(reset, carry.inner_carry) + + shadow = None + lyap_sum = None + spec_shadows = None + spec_h_dirs = None + spec_l_dirs = None + lyap_spec_sum = None + if fd_spectrum_k > 0: + spec_h_dirs, spec_l_dirs = _rand_orthonormal_dirs_like(main, fd_spectrum_k, generator) + spec_shadows = _make_spectrum_shadows(main, spec_h_dirs, spec_l_dirs, fd_eps) + lyap_spec_sum = torch.zeros(total, fd_spectrum_k, device=device, dtype=torch.float32) + elif fd_lyap: + dh, dl = _rand_unit_like(main, generator) + shadow = TinyRecursiveReasoningModel_ACTV1InnerCarry( + z_H=(main.z_H + fd_eps * dh).detach(), + z_L=(main.z_L + fd_eps * dl).detach(), + ) + lyap_sum = torch.zeros(total, device=device, dtype=torch.float32) + + logits = None + q_halt = None + q_continue = None + for _ in range(steps): + noise_h = _sample_noise(main.z_H.shape, noise_std, generator, main.z_H.dtype, device) + noise_l = _sample_noise(main.z_L.shape, noise_std, generator, main.z_L.dtype, device) + if include_clean and rollouts > 1: + clean_mask = (rollout_id == 0).view((-1,) + (1,) * (main.z_H.ndim - 1)) + noise_h = torch.where(clean_mask, torch.zeros_like(noise_h), noise_h) + noise_l = torch.where(clean_mask, torch.zeros_like(noise_l), noise_l) + + main = _apply_step_noise(main, noise_h, noise_l, perturb) + if fd_spectrum_k > 0: + assert spec_shadows is not None and lyap_spec_sum is not None + shadow_noise_h = noise_h.repeat_interleave(fd_spectrum_k, dim=0) + shadow_noise_l = noise_l.repeat_interleave(fd_spectrum_k, dim=0) + spec_shadows = _apply_step_noise(spec_shadows, shadow_noise_h, shadow_noise_l, perturb) + combined_inner = _cat_inner(main, spec_shadows) + combined_batch = _cat_many_batches([expanded, _repeat_inner_batch(expanded, fd_spectrum_k)]) + combined_inner, combined_logits, (combined_q_halt, combined_q_continue) = model.inner(combined_inner, combined_batch) + main, spec_shadows = _split_spectrum_inner(combined_inner, total) + logits = combined_logits[:total] + q_halt = combined_q_halt[:total] + q_continue = combined_q_continue[:total] + + delta_h = ( + spec_shadows.z_H.reshape((total, fd_spectrum_k) + tuple(main.z_H.shape[1:])).float() + - main.z_H[:, None].float() + ) / fd_eps + delta_l = ( + spec_shadows.z_L.reshape((total, fd_spectrum_k) + tuple(main.z_L.shape[1:])).float() + - main.z_L[:, None].float() + ) / fd_eps + q, r = torch.linalg.qr(_dirs_to_q(delta_h, delta_l), mode="reduced") + diag = torch.diagonal(r, dim1=-2, dim2=-1).abs().clamp_min(1e-30) + lyap_spec_sum = lyap_spec_sum + torch.log(diag).float() + spec_h_dirs, spec_l_dirs = _q_to_dirs( + q, main.z_H.shape[1:], main.z_L.shape[1:], main.z_H.dtype, main.z_L.dtype + ) + spec_shadows = _make_spectrum_shadows(main, spec_h_dirs, spec_l_dirs, fd_eps) + elif fd_lyap: + assert shadow is not None + shadow = _apply_step_noise(shadow, noise_h, noise_l, perturb) + combined_inner = _cat_inner(main, shadow) + combined_batch = cat_batches(expanded, expanded) + combined_inner, combined_logits, (combined_q_halt, combined_q_continue) = model.inner(combined_inner, combined_batch) + main, shadow = _split_inner(combined_inner, total) + logits = combined_logits[:total] + q_halt = combined_q_halt[:total] + q_continue = combined_q_continue[:total] + sep = _separation(main, shadow) + lyap_sum = lyap_sum + torch.log(sep / fd_eps).float() # type: ignore[operator] + shadow = _renormalize_shadow(main, shadow, fd_eps) + else: + main, logits, (q_halt, q_continue) = model.inner(main, expanded) + + assert logits is not None and q_halt is not None + exact, token_acc = correctness(logits, expanded["labels"]) + exact = exact.view(base_batch_size, rollouts) + token_acc = token_acc.view(base_batch_size, rollouts) + q_halt = q_halt.float().view(base_batch_size, rollouts) + q_continue = q_continue.float().view(base_batch_size, rollouts) if q_continue is not None else torch.zeros_like(q_halt) + lyap = None + lyap_spec = None + if fd_spectrum_k > 0: + assert lyap_spec_sum is not None + lyap_spec = (lyap_spec_sum / max(steps, 1)).view(base_batch_size, rollouts, fd_spectrum_k) + lyap_spec = torch.sort(lyap_spec, dim=-1, descending=True).values + lyap = lyap_spec[..., 0] + elif fd_lyap: + assert lyap_sum is not None + lyap = (lyap_sum / max(steps, 1)).view(base_batch_size, rollouts) + return exact, token_acc, q_halt, q_continue, lyap, lyap_spec + + +def _take_by_idx(values: torch.Tensor, idx: torch.Tensor): + return values.gather(1, idx[:, None]).squeeze(1) + + +def _zscore_per_row(values: torch.Tensor): + return (values - values.mean(dim=1, keepdim=True)) / values.std(dim=1, keepdim=True).clamp_min(1e-6) + + +def summarize_selectors(exact, token_acc, q_halt, lyap, lyap_spec=None): + out: dict[str, float] = {} + bsz, rollouts = exact.shape + arange = torch.arange(bsz, device=exact.device) + correct_counts = exact.float().sum(dim=1) + + selectors = { + "rollout0": torch.zeros(bsz, device=exact.device, dtype=torch.long), + "q_max": q_halt.argmax(dim=1), + "oracle_pass": None, + } + if lyap is not None: + selectors["lyap_min"] = lyap.argmin(dim=1) + qz = _zscore_per_row(q_halt) + lz = _zscore_per_row(lyap) + for alpha in (0.25, 0.5, 1.0, 2.0): + selectors[f"q_minus_{alpha:g}lambda"] = (qz - alpha * lz).argmax(dim=1) + if lyap_spec is not None: + spec_pos_mass = lyap_spec.clamp_min(0).sum(dim=-1) + spec_pos_l2 = lyap_spec.clamp_min(0).square().mean(dim=-1).sqrt() + spec_mean = lyap_spec.mean(dim=-1) + spec_count_pos = (lyap_spec > 0).float().sum(dim=-1) + spec_spread = lyap_spec[..., 0] - lyap_spec[..., -1] + + selectors["spec_pos_mass_min"] = spec_pos_mass.argmin(dim=1) + selectors["spec_pos_l2_min"] = spec_pos_l2.argmin(dim=1) + selectors["spec_mean_min"] = spec_mean.argmin(dim=1) + selectors["spec_count_pos_min"] = spec_count_pos.argmin(dim=1) + selectors["spec_spread_min"] = spec_spread.argmin(dim=1) + + for name, idx in selectors.items(): + if idx is None: + out[f"{name}/exact"] = exact.any(dim=1).float().mean().item() + out[f"{name}/token_acc"] = token_acc.max(dim=1).values.mean().item() + else: + out[f"{name}/exact"] = exact[arange, idx].float().mean().item() + out[f"{name}/token_acc"] = token_acc[arange, idx].mean().item() + + out["mean_rollout/exact"] = exact.float().mean().item() + out["mean_rollout/token_acc"] = token_acc.mean().item() + out["correct_count/mean"] = correct_counts.mean().item() + out["correct_count/std"] = correct_counts.std(unbiased=False).item() + out["correct_count/median"] = correct_counts.median().item() + out["correct_count/q10"] = torch.quantile(correct_counts, 0.10).item() + out["correct_count/q25"] = torch.quantile(correct_counts, 0.25).item() + out["correct_count/q75"] = torch.quantile(correct_counts, 0.75).item() + out["correct_count/q90"] = torch.quantile(correct_counts, 0.90).item() + out["correct_count/zero_frac"] = (correct_counts == 0).float().mean().item() + out["correct_count/full_frac"] = (correct_counts == rollouts).float().mean().item() + for threshold in (1, 5, 10, 25, 50, 75, 90): + if threshold <= rollouts: + out[f"correct_count/ge_{threshold}_frac"] = (correct_counts >= threshold).float().mean().item() + out["q_mean"] = q_halt.mean().item() + if lyap is not None: + out["lambda_mean"] = lyap.mean().item() + if exact.any().item() and (~exact).any().item(): + out["lambda_success_mean"] = lyap[exact].mean().item() + out["lambda_fail_mean"] = lyap[~exact].mean().item() + out["q_success_mean"] = q_halt[exact].mean().item() + out["q_fail_mean"] = q_halt[~exact].mean().item() + if lyap_spec is not None: + spec_pos_mass = lyap_spec.clamp_min(0).sum(dim=-1) + spec_pos_l2 = lyap_spec.clamp_min(0).square().mean(dim=-1).sqrt() + spec_mean = lyap_spec.mean(dim=-1) + spec_count_pos = (lyap_spec > 0).float().sum(dim=-1) + spec_spread = lyap_spec[..., 0] - lyap_spec[..., -1] + out["spec_k"] = float(lyap_spec.shape[-1]) + out["spec_pos_mass_mean"] = spec_pos_mass.mean().item() + out["spec_pos_l2_mean"] = spec_pos_l2.mean().item() + out["spec_mean_mean"] = spec_mean.mean().item() + out["spec_count_pos_mean"] = spec_count_pos.mean().item() + out["spec_spread_mean"] = spec_spread.mean().item() + if exact.any().item() and (~exact).any().item(): + out["spec_pos_mass_success_mean"] = spec_pos_mass[exact].mean().item() + out["spec_pos_mass_fail_mean"] = spec_pos_mass[~exact].mean().item() + out["spec_mean_success_mean"] = spec_mean[exact].mean().item() + out["spec_mean_fail_mean"] = spec_mean[~exact].mean().item() + out["spec_count_pos_success_mean"] = spec_count_pos[exact].mean().item() + out["spec_count_pos_fail_mean"] = spec_count_pos[~exact].mean().item() + return out + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt-root", required=True) + parser.add_argument("--ckpt-name", default="step_260410") + parser.add_argument("--n-samples", type=int, default=512) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--rollouts", type=int, default=8) + parser.add_argument("--steps", type=int, default=16) + parser.add_argument("--noise-std", type=float, default=1e-3) + parser.add_argument("--include-clean", action="store_true") + parser.add_argument("--perturb", choices=["h", "l", "both"], default="both") + parser.add_argument("--fd-lyap", action="store_true") + parser.add_argument("--fd-spectrum-k", type=int, default=0) + parser.add_argument("--fd-eps", type=float, default=1e-2) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--out-prefix", default="research/flossing/ptrm_selection") + args = parser.parse_args() + + device = "cuda" + torch.manual_seed(args.seed) + generator = torch.Generator(device=device).manual_seed(args.seed + 12345) + + ckpt_root = Path(args.ckpt_root) + model, cfg = load_model(ckpt_root, args.ckpt_name, device) + samples = load_test_samples(Path(cfg["data_paths"][0]), args.n_samples, args.seed) + n = len(samples["inputs"]) + + all_det_exact, all_det_token = [], [] + all_exact, all_token, all_q, all_q_continue, all_lam, all_spec = [], [], [], [], [], [] + + for start in range(0, n, args.batch_size): + end = min(start + args.batch_size, n) + batch = batch_slice(samples, start, end, device) + det_exact, det_token, _det_q, det_steps = deterministic_eval(model, batch) + exact, token_acc, q_halt, q_continue, lyap, lyap_spec = ptrm_rollouts( + model=model, + batch=batch, + rollouts=args.rollouts, + steps=args.steps, + noise_std=args.noise_std, + include_clean=args.include_clean, + perturb=args.perturb, + fd_lyap=args.fd_lyap, + fd_spectrum_k=args.fd_spectrum_k, + fd_eps=args.fd_eps, + generator=generator, + ) + all_det_exact.append(det_exact.cpu()) + all_det_token.append(det_token.cpu()) + all_exact.append(exact.cpu()) + all_token.append(token_acc.cpu()) + all_q.append(q_halt.cpu()) + all_q_continue.append(q_continue.cpu()) + if lyap is not None: + all_lam.append(lyap.cpu()) + if lyap_spec is not None: + all_spec.append(lyap_spec.cpu()) + print( + f"[{end}/{n}] det={det_exact.float().mean().item():.4f} " + f"q_sel={_take_by_idx(exact, q_halt.argmax(1)).float().mean().item():.4f} " + f"pass@K={exact.any(1).float().mean().item():.4f} steps={det_steps}", + flush=True, + ) + + det_exact = torch.cat(all_det_exact) + det_token = torch.cat(all_det_token) + exact = torch.cat(all_exact) + token_acc = torch.cat(all_token) + q_halt = torch.cat(all_q) + q_continue = torch.cat(all_q_continue) + lyap = torch.cat(all_lam) if all_lam else None + lyap_spec = torch.cat(all_spec) if all_spec else None + summary = summarize_selectors(exact, token_acc, q_halt, lyap, lyap_spec) + summary["deterministic/exact"] = det_exact.float().mean().item() + summary["deterministic/token_acc"] = det_token.mean().item() + correct_counts = exact.float().sum(dim=1) + oracle_success = exact.any(dim=1) + q_selected = exact[torch.arange(exact.shape[0]), q_halt.argmax(dim=1)] + det_success = det_exact.bool() + det_fail = ~det_success + if det_success.any().item(): + summary["correct_count/det_success_mean"] = correct_counts[det_success].mean().item() + summary["oracle_pass/det_success_frac"] = oracle_success[det_success].float().mean().item() + summary["q_max/det_success_frac"] = q_selected[det_success].float().mean().item() + if det_fail.any().item(): + summary["correct_count/det_fail_mean"] = correct_counts[det_fail].mean().item() + summary["oracle_pass/det_fail_frac"] = oracle_success[det_fail].float().mean().item() + summary["q_max/det_fail_frac"] = q_selected[det_fail].float().mean().item() + summary["n_samples"] = float(n) + summary["rollouts"] = float(args.rollouts) + summary["noise_std"] = float(args.noise_std) + summary["include_clean"] = float(args.include_clean) + summary["fd_lyap"] = float(args.fd_lyap) + summary["fd_spectrum_k"] = float(args.fd_spectrum_k) + summary["steps"] = float(args.steps) + summary["perturb_l"] = float(args.perturb == "l") + summary["perturb_h"] = float(args.perturb == "h") + summary["perturb_both"] = float(args.perturb == "both") + + out_prefix = Path(args.out_prefix) + out_prefix.parent.mkdir(parents=True, exist_ok=True) + meta = { + "ckpt_root": str(ckpt_root), + "ckpt_name": args.ckpt_name, + "n_samples": n, + "batch_size": args.batch_size, + "rollouts": args.rollouts, + "steps": args.steps, + "noise_std": args.noise_std, + "include_clean": args.include_clean, + "perturb": args.perturb, + "fd_lyap": args.fd_lyap, + "fd_spectrum_k": args.fd_spectrum_k, + "fd_eps": args.fd_eps, + "seed": args.seed, + } + np.savez_compressed( + f"{out_prefix}.npz", + idx=samples["idx"], + det_exact=det_exact.numpy(), + det_token_acc=det_token.numpy(), + exact=exact.numpy(), + token_acc=token_acc.numpy(), + q_halt=q_halt.numpy(), + q_continue=q_continue.numpy(), + lyap=np.asarray([]) if lyap is None else lyap.numpy(), + lyap_spec=np.asarray([]) if lyap_spec is None else lyap_spec.numpy(), + meta_json=np.asarray(json.dumps(meta, sort_keys=True)), + ) + with open(f"{out_prefix}.meta.json", "w") as f: + json.dump(meta, f, indent=2, sort_keys=True) + with open(f"{out_prefix}.summary.csv", "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=sorted(summary)) + writer.writeheader() + writer.writerow(summary) + + print("\nsummary") + for key in sorted(summary): + print(f"{key}: {summary[key]}") + print(f"\nsaved {out_prefix}.npz and {out_prefix}.summary.csv") + + +if __name__ == "__main__": + main() |
