"""PTRM-style stochastic rollout evaluation with Q and stability selection. This is an inference-time experiment: no training, no weight updates. For each input, run K stochastic recursive trajectories by injecting Gaussian noise into the latent state before every ACT step. Select a trajectory by: - Q head score (PTRM) - finite-difference top Lyapunov proxy (lowest lambda) - finite-difference low-rank Lyapunov spectrum proxies - simple Q/lambda hybrid scores The Lyapunov proxy is computed by pairing each rollout with a tiny shadow trajectory that receives the same stochastic noise and is renormalized after each ACT step. This is much cheaper than JVP-based exact spectrum estimation and is enough to test whether stability can act as a free selector. The optional spectrum proxy generalizes the shadow trajectory to k orthogonal shadows and uses QR re-orthogonalization after every ACT step. This estimates the top-k finite-time spectrum in a random tangent subspace. It is much more expensive than top-1 because the model batch is multiplied by k + 1. """ from __future__ import annotations import argparse import csv import json import math import sys from dataclasses import replace from pathlib import Path from typing import Any import numpy as np import torch TRM_DIR = Path("/home/yurenh2/rrm/trm") sys.path.insert(0, str(TRM_DIR)) from models.recursive_reasoning.trm import ( # noqa: E402 TinyRecursiveReasoningModel_ACTV1, TinyRecursiveReasoningModel_ACTV1InnerCarry, ) IGNORE_LABEL_ID = -100 def load_model(ckpt_root: Path, ckpt_name: str, device: str): cfg = json.loads(json.dumps(__import__("yaml").safe_load((ckpt_root / "all_config.yaml").read_text()))) train_meta = json.loads((Path(cfg["data_paths"][0]) / "train" / "dataset.json").read_text()) arch_cfg = dict(cfg["arch"]) arch_cfg.update( batch_size=cfg["global_batch_size"], seq_len=train_meta["seq_len"], vocab_size=train_meta["vocab_size"], num_puzzle_identifiers=train_meta["num_puzzle_identifiers"], ) model = TinyRecursiveReasoningModel_ACTV1(arch_cfg) state = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True) stripped = {k.replace("_orig_mod.", "").replace("model.", ""): v for k, v in state.items()} missing, unexpected = model.load_state_dict(stripped, strict=False) print(f"[load] {ckpt_root.name}/{ckpt_name} missing={len(missing)} unexpected={len(unexpected)}") if missing[:3]: print(f"[load] sample missing: {missing[:3]}") if unexpected[:3]: print(f"[load] sample unexpected: {unexpected[:3]}") model.to(device).eval() return model, cfg def load_test_samples(data_path: Path, n_samples: int, seed: int): rng = np.random.default_rng(seed) inputs = np.load(data_path / "test" / "all__inputs.npy") labels = np.load(data_path / "test" / "all__labels.npy") puzzle_ids = np.load(data_path / "test" / "all__puzzle_identifiers.npy") n = min(n_samples, len(inputs)) idx = rng.choice(len(inputs), size=n, replace=False) return { "inputs": torch.from_numpy(inputs[idx].astype(np.int32)), "labels": torch.from_numpy(labels[idx].astype(np.int32)), "puzzle_identifiers": torch.from_numpy(puzzle_ids[idx].astype(np.int32)), "idx": idx, } def batch_slice(samples: dict[str, Any], start: int, end: int, device: str): return { k: v[start:end].to(device, non_blocking=True) for k, v in samples.items() if k in ("inputs", "labels", "puzzle_identifiers") } def repeat_batch(batch: dict[str, torch.Tensor], repeats: int): if repeats == 1: return batch return {k: v.repeat_interleave(repeats, dim=0) for k, v in batch.items()} def cat_batches(a: dict[str, torch.Tensor], b: dict[str, torch.Tensor]): return {k: torch.cat([a[k], b[k]], dim=0) for k in a} def _rand_unit_like(inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, generator: torch.Generator): dh = torch.randn(inner.z_H.shape, device=inner.z_H.device, dtype=torch.float32, generator=generator) dl = torch.randn(inner.z_L.shape, device=inner.z_L.device, dtype=torch.float32, generator=generator) norm = torch.sqrt(dh.flatten(1).square().sum(-1) + dl.flatten(1).square().sum(-1)).clamp_min(1e-30) view_h = (dh.shape[0],) + (1,) * (dh.ndim - 1) view_l = (dl.shape[0],) + (1,) * (dl.ndim - 1) return (dh / norm.view(view_h)).to(inner.z_H.dtype), (dl / norm.view(view_l)).to(inner.z_L.dtype) def _q_to_dirs( q: torch.Tensor, z_h_shape: torch.Size, z_l_shape: torch.Size, h_dtype: torch.dtype, l_dtype: torch.dtype, ): total, _dim, spec_k = q.shape h_numel = math.prod(z_h_shape) q_t = q.transpose(1, 2).contiguous() h_dirs = q_t[:, :, :h_numel].reshape((total, spec_k) + tuple(z_h_shape)).to(h_dtype) l_dirs = q_t[:, :, h_numel:].reshape((total, spec_k) + tuple(z_l_shape)).to(l_dtype) return h_dirs, l_dirs def _dirs_to_q(h_dirs: torch.Tensor, l_dirs: torch.Tensor): q_t = torch.cat([h_dirs.float().flatten(2), l_dirs.float().flatten(2)], dim=2) return q_t.transpose(1, 2).contiguous() def _rand_orthonormal_dirs_like( inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, spec_k: int, generator: torch.Generator, ): total = inner.z_H.shape[0] h_dirs = torch.randn( (total, spec_k) + tuple(inner.z_H.shape[1:]), device=inner.z_H.device, dtype=torch.float32, generator=generator, ) l_dirs = torch.randn( (total, spec_k) + tuple(inner.z_L.shape[1:]), device=inner.z_L.device, dtype=torch.float32, generator=generator, ) q, _ = torch.linalg.qr(_dirs_to_q(h_dirs, l_dirs), mode="reduced") return _q_to_dirs(q, inner.z_H.shape[1:], inner.z_L.shape[1:], inner.z_H.dtype, inner.z_L.dtype) def _make_spectrum_shadows( main: TinyRecursiveReasoningModel_ACTV1InnerCarry, h_dirs: torch.Tensor, l_dirs: torch.Tensor, eps: float, ): total, spec_k = h_dirs.shape[:2] z_h = main.z_H[:, None] + eps * h_dirs.to(main.z_H.dtype) z_l = main.z_L[:, None] + eps * l_dirs.to(main.z_L.dtype) return TinyRecursiveReasoningModel_ACTV1InnerCarry( z_H=z_h.reshape((total * spec_k,) + tuple(main.z_H.shape[1:])).detach(), z_L=z_l.reshape((total * spec_k,) + tuple(main.z_L.shape[1:])).detach(), ) def _repeat_inner_batch(batch: dict[str, torch.Tensor], repeats: int): return {k: v.repeat_interleave(repeats, dim=0) for k, v in batch.items()} def _cat_many_batches(batches: list[dict[str, torch.Tensor]]): return {k: torch.cat([b[k] for b in batches], dim=0) for k in batches[0]} def _split_inner(inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, n: int): return ( TinyRecursiveReasoningModel_ACTV1InnerCarry( z_H=inner.z_H[:n].contiguous(), z_L=inner.z_L[:n].contiguous(), ), TinyRecursiveReasoningModel_ACTV1InnerCarry( z_H=inner.z_H[n:].contiguous(), z_L=inner.z_L[n:].contiguous(), ), ) def _split_spectrum_inner(inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, n_main: int): return ( TinyRecursiveReasoningModel_ACTV1InnerCarry( z_H=inner.z_H[:n_main].contiguous(), z_L=inner.z_L[:n_main].contiguous(), ), TinyRecursiveReasoningModel_ACTV1InnerCarry( z_H=inner.z_H[n_main:].contiguous(), z_L=inner.z_L[n_main:].contiguous(), ), ) def _cat_inner(a: TinyRecursiveReasoningModel_ACTV1InnerCarry, b: TinyRecursiveReasoningModel_ACTV1InnerCarry): return TinyRecursiveReasoningModel_ACTV1InnerCarry( z_H=torch.cat([a.z_H, b.z_H], dim=0), z_L=torch.cat([a.z_L, b.z_L], dim=0), ) def _sample_noise( shape: torch.Size, std: float, generator: torch.Generator, dtype: torch.dtype, device: torch.device, ): if std <= 0: return torch.zeros(shape, device=device, dtype=dtype) return (std * torch.randn(shape, device=device, dtype=torch.float32, generator=generator)).to(dtype) def _apply_step_noise( inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, noise_h: torch.Tensor, noise_l: torch.Tensor, perturb: str, ): z_h, z_l = inner.z_H, inner.z_L if perturb in ("h", "both"): z_h = z_h + noise_h if perturb in ("l", "both"): z_l = z_l + noise_l return replace(inner, z_H=z_h, z_L=z_l) def _separation( main: TinyRecursiveReasoningModel_ACTV1InnerCarry, shadow: TinyRecursiveReasoningModel_ACTV1InnerCarry, ): dh = (shadow.z_H.float() - main.z_H.float()).flatten(1) dl = (shadow.z_L.float() - main.z_L.float()).flatten(1) return torch.sqrt(dh.square().sum(-1) + dl.square().sum(-1)).clamp_min(1e-30) def _renormalize_shadow( main: TinyRecursiveReasoningModel_ACTV1InnerCarry, shadow: TinyRecursiveReasoningModel_ACTV1InnerCarry, eps: float, ): sep = _separation(main, shadow) view_h = (sep.shape[0],) + (1,) * (main.z_H.ndim - 1) view_l = (sep.shape[0],) + (1,) * (main.z_L.ndim - 1) scale_h = (eps / sep).view(view_h).to(main.z_H.dtype) scale_l = (eps / sep).view(view_l).to(main.z_L.dtype) return TinyRecursiveReasoningModel_ACTV1InnerCarry( z_H=(main.z_H + (shadow.z_H - main.z_H) * scale_h).detach(), z_L=(main.z_L + (shadow.z_L - main.z_L) * scale_l).detach(), ) @torch.inference_mode() def deterministic_eval(model, batch: dict[str, torch.Tensor]): with torch.device(batch["inputs"].device): carry = model.initial_carry(batch) logits = None q_halt = None steps = 0 while True: carry, outputs = model(carry=carry, batch=batch) logits = outputs["logits"] q_halt = outputs["q_halt_logits"] steps += 1 if bool(carry.halted.all()): break exact, token_acc = correctness(logits, batch["labels"]) return exact, token_acc, q_halt, steps def correctness(logits: torch.Tensor, labels: torch.Tensor): preds = logits.argmax(dim=-1) mask = labels != IGNORE_LABEL_ID exact = torch.where(mask, preds == labels, True).all(dim=-1) denom = mask.sum(-1).clamp_min(1) token_acc = ((preds == labels) & mask).sum(-1).float() / denom.float() return exact, token_acc @torch.inference_mode() def ptrm_rollouts( model, batch: dict[str, torch.Tensor], rollouts: int, steps: int, noise_std: float, include_clean: bool, perturb: str, fd_lyap: bool, fd_spectrum_k: int, fd_eps: float, generator: torch.Generator, ): device = batch["inputs"].device base_batch_size = batch["inputs"].shape[0] expanded = repeat_batch(batch, rollouts) total = expanded["inputs"].shape[0] rollout_id = torch.arange(total, device=device) % rollouts with torch.device(device): carry = model.initial_carry(expanded) reset = torch.ones_like(carry.halted) main = model.inner.reset_carry(reset, carry.inner_carry) shadow = None lyap_sum = None spec_shadows = None spec_h_dirs = None spec_l_dirs = None lyap_spec_sum = None if fd_spectrum_k > 0: spec_h_dirs, spec_l_dirs = _rand_orthonormal_dirs_like(main, fd_spectrum_k, generator) spec_shadows = _make_spectrum_shadows(main, spec_h_dirs, spec_l_dirs, fd_eps) lyap_spec_sum = torch.zeros(total, fd_spectrum_k, device=device, dtype=torch.float32) elif fd_lyap: dh, dl = _rand_unit_like(main, generator) shadow = TinyRecursiveReasoningModel_ACTV1InnerCarry( z_H=(main.z_H + fd_eps * dh).detach(), z_L=(main.z_L + fd_eps * dl).detach(), ) lyap_sum = torch.zeros(total, device=device, dtype=torch.float32) logits = None q_halt = None q_continue = None for _ in range(steps): noise_h = _sample_noise(main.z_H.shape, noise_std, generator, main.z_H.dtype, device) noise_l = _sample_noise(main.z_L.shape, noise_std, generator, main.z_L.dtype, device) if include_clean and rollouts > 1: clean_mask = (rollout_id == 0).view((-1,) + (1,) * (main.z_H.ndim - 1)) noise_h = torch.where(clean_mask, torch.zeros_like(noise_h), noise_h) noise_l = torch.where(clean_mask, torch.zeros_like(noise_l), noise_l) main = _apply_step_noise(main, noise_h, noise_l, perturb) if fd_spectrum_k > 0: assert spec_shadows is not None and lyap_spec_sum is not None shadow_noise_h = noise_h.repeat_interleave(fd_spectrum_k, dim=0) shadow_noise_l = noise_l.repeat_interleave(fd_spectrum_k, dim=0) spec_shadows = _apply_step_noise(spec_shadows, shadow_noise_h, shadow_noise_l, perturb) combined_inner = _cat_inner(main, spec_shadows) combined_batch = _cat_many_batches([expanded, _repeat_inner_batch(expanded, fd_spectrum_k)]) combined_inner, combined_logits, (combined_q_halt, combined_q_continue) = model.inner(combined_inner, combined_batch) main, spec_shadows = _split_spectrum_inner(combined_inner, total) logits = combined_logits[:total] q_halt = combined_q_halt[:total] q_continue = combined_q_continue[:total] delta_h = ( spec_shadows.z_H.reshape((total, fd_spectrum_k) + tuple(main.z_H.shape[1:])).float() - main.z_H[:, None].float() ) / fd_eps delta_l = ( spec_shadows.z_L.reshape((total, fd_spectrum_k) + tuple(main.z_L.shape[1:])).float() - main.z_L[:, None].float() ) / fd_eps q, r = torch.linalg.qr(_dirs_to_q(delta_h, delta_l), mode="reduced") diag = torch.diagonal(r, dim1=-2, dim2=-1).abs().clamp_min(1e-30) lyap_spec_sum = lyap_spec_sum + torch.log(diag).float() spec_h_dirs, spec_l_dirs = _q_to_dirs( q, main.z_H.shape[1:], main.z_L.shape[1:], main.z_H.dtype, main.z_L.dtype ) spec_shadows = _make_spectrum_shadows(main, spec_h_dirs, spec_l_dirs, fd_eps) elif fd_lyap: assert shadow is not None shadow = _apply_step_noise(shadow, noise_h, noise_l, perturb) combined_inner = _cat_inner(main, shadow) combined_batch = cat_batches(expanded, expanded) combined_inner, combined_logits, (combined_q_halt, combined_q_continue) = model.inner(combined_inner, combined_batch) main, shadow = _split_inner(combined_inner, total) logits = combined_logits[:total] q_halt = combined_q_halt[:total] q_continue = combined_q_continue[:total] sep = _separation(main, shadow) lyap_sum = lyap_sum + torch.log(sep / fd_eps).float() # type: ignore[operator] shadow = _renormalize_shadow(main, shadow, fd_eps) else: main, logits, (q_halt, q_continue) = model.inner(main, expanded) assert logits is not None and q_halt is not None exact, token_acc = correctness(logits, expanded["labels"]) exact = exact.view(base_batch_size, rollouts) token_acc = token_acc.view(base_batch_size, rollouts) q_halt = q_halt.float().view(base_batch_size, rollouts) q_continue = q_continue.float().view(base_batch_size, rollouts) if q_continue is not None else torch.zeros_like(q_halt) lyap = None lyap_spec = None if fd_spectrum_k > 0: assert lyap_spec_sum is not None lyap_spec = (lyap_spec_sum / max(steps, 1)).view(base_batch_size, rollouts, fd_spectrum_k) lyap_spec = torch.sort(lyap_spec, dim=-1, descending=True).values lyap = lyap_spec[..., 0] elif fd_lyap: assert lyap_sum is not None lyap = (lyap_sum / max(steps, 1)).view(base_batch_size, rollouts) return exact, token_acc, q_halt, q_continue, lyap, lyap_spec def _take_by_idx(values: torch.Tensor, idx: torch.Tensor): return values.gather(1, idx[:, None]).squeeze(1) def _zscore_per_row(values: torch.Tensor): return (values - values.mean(dim=1, keepdim=True)) / values.std(dim=1, keepdim=True).clamp_min(1e-6) def summarize_selectors(exact, token_acc, q_halt, lyap, lyap_spec=None): out: dict[str, float] = {} bsz, rollouts = exact.shape arange = torch.arange(bsz, device=exact.device) correct_counts = exact.float().sum(dim=1) selectors = { "rollout0": torch.zeros(bsz, device=exact.device, dtype=torch.long), "q_max": q_halt.argmax(dim=1), "oracle_pass": None, } if lyap is not None: selectors["lyap_min"] = lyap.argmin(dim=1) qz = _zscore_per_row(q_halt) lz = _zscore_per_row(lyap) for alpha in (0.25, 0.5, 1.0, 2.0): selectors[f"q_minus_{alpha:g}lambda"] = (qz - alpha * lz).argmax(dim=1) if lyap_spec is not None: spec_pos_mass = lyap_spec.clamp_min(0).sum(dim=-1) spec_pos_l2 = lyap_spec.clamp_min(0).square().mean(dim=-1).sqrt() spec_mean = lyap_spec.mean(dim=-1) spec_count_pos = (lyap_spec > 0).float().sum(dim=-1) spec_spread = lyap_spec[..., 0] - lyap_spec[..., -1] selectors["spec_pos_mass_min"] = spec_pos_mass.argmin(dim=1) selectors["spec_pos_l2_min"] = spec_pos_l2.argmin(dim=1) selectors["spec_mean_min"] = spec_mean.argmin(dim=1) selectors["spec_count_pos_min"] = spec_count_pos.argmin(dim=1) selectors["spec_spread_min"] = spec_spread.argmin(dim=1) for name, idx in selectors.items(): if idx is None: out[f"{name}/exact"] = exact.any(dim=1).float().mean().item() out[f"{name}/token_acc"] = token_acc.max(dim=1).values.mean().item() else: out[f"{name}/exact"] = exact[arange, idx].float().mean().item() out[f"{name}/token_acc"] = token_acc[arange, idx].mean().item() out["mean_rollout/exact"] = exact.float().mean().item() out["mean_rollout/token_acc"] = token_acc.mean().item() out["correct_count/mean"] = correct_counts.mean().item() out["correct_count/std"] = correct_counts.std(unbiased=False).item() out["correct_count/median"] = correct_counts.median().item() out["correct_count/q10"] = torch.quantile(correct_counts, 0.10).item() out["correct_count/q25"] = torch.quantile(correct_counts, 0.25).item() out["correct_count/q75"] = torch.quantile(correct_counts, 0.75).item() out["correct_count/q90"] = torch.quantile(correct_counts, 0.90).item() out["correct_count/zero_frac"] = (correct_counts == 0).float().mean().item() out["correct_count/full_frac"] = (correct_counts == rollouts).float().mean().item() for threshold in (1, 5, 10, 25, 50, 75, 90): if threshold <= rollouts: out[f"correct_count/ge_{threshold}_frac"] = (correct_counts >= threshold).float().mean().item() out["q_mean"] = q_halt.mean().item() if lyap is not None: out["lambda_mean"] = lyap.mean().item() if exact.any().item() and (~exact).any().item(): out["lambda_success_mean"] = lyap[exact].mean().item() out["lambda_fail_mean"] = lyap[~exact].mean().item() out["q_success_mean"] = q_halt[exact].mean().item() out["q_fail_mean"] = q_halt[~exact].mean().item() if lyap_spec is not None: spec_pos_mass = lyap_spec.clamp_min(0).sum(dim=-1) spec_pos_l2 = lyap_spec.clamp_min(0).square().mean(dim=-1).sqrt() spec_mean = lyap_spec.mean(dim=-1) spec_count_pos = (lyap_spec > 0).float().sum(dim=-1) spec_spread = lyap_spec[..., 0] - lyap_spec[..., -1] out["spec_k"] = float(lyap_spec.shape[-1]) out["spec_pos_mass_mean"] = spec_pos_mass.mean().item() out["spec_pos_l2_mean"] = spec_pos_l2.mean().item() out["spec_mean_mean"] = spec_mean.mean().item() out["spec_count_pos_mean"] = spec_count_pos.mean().item() out["spec_spread_mean"] = spec_spread.mean().item() if exact.any().item() and (~exact).any().item(): out["spec_pos_mass_success_mean"] = spec_pos_mass[exact].mean().item() out["spec_pos_mass_fail_mean"] = spec_pos_mass[~exact].mean().item() out["spec_mean_success_mean"] = spec_mean[exact].mean().item() out["spec_mean_fail_mean"] = spec_mean[~exact].mean().item() out["spec_count_pos_success_mean"] = spec_count_pos[exact].mean().item() out["spec_count_pos_fail_mean"] = spec_count_pos[~exact].mean().item() return out def main(): parser = argparse.ArgumentParser() parser.add_argument("--ckpt-root", required=True) parser.add_argument("--ckpt-name", default="step_260410") parser.add_argument("--n-samples", type=int, default=512) parser.add_argument("--batch-size", type=int, default=32) parser.add_argument("--rollouts", type=int, default=8) parser.add_argument("--steps", type=int, default=16) parser.add_argument("--noise-std", type=float, default=1e-3) parser.add_argument("--include-clean", action="store_true") parser.add_argument("--perturb", choices=["h", "l", "both"], default="both") parser.add_argument("--fd-lyap", action="store_true") parser.add_argument("--fd-spectrum-k", type=int, default=0) parser.add_argument("--fd-eps", type=float, default=1e-2) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--out-prefix", default="research/flossing/ptrm_selection") args = parser.parse_args() device = "cuda" torch.manual_seed(args.seed) generator = torch.Generator(device=device).manual_seed(args.seed + 12345) ckpt_root = Path(args.ckpt_root) model, cfg = load_model(ckpt_root, args.ckpt_name, device) samples = load_test_samples(Path(cfg["data_paths"][0]), args.n_samples, args.seed) n = len(samples["inputs"]) all_det_exact, all_det_token = [], [] all_exact, all_token, all_q, all_q_continue, all_lam, all_spec = [], [], [], [], [], [] for start in range(0, n, args.batch_size): end = min(start + args.batch_size, n) batch = batch_slice(samples, start, end, device) det_exact, det_token, _det_q, det_steps = deterministic_eval(model, batch) exact, token_acc, q_halt, q_continue, lyap, lyap_spec = ptrm_rollouts( model=model, batch=batch, rollouts=args.rollouts, steps=args.steps, noise_std=args.noise_std, include_clean=args.include_clean, perturb=args.perturb, fd_lyap=args.fd_lyap, fd_spectrum_k=args.fd_spectrum_k, fd_eps=args.fd_eps, generator=generator, ) all_det_exact.append(det_exact.cpu()) all_det_token.append(det_token.cpu()) all_exact.append(exact.cpu()) all_token.append(token_acc.cpu()) all_q.append(q_halt.cpu()) all_q_continue.append(q_continue.cpu()) if lyap is not None: all_lam.append(lyap.cpu()) if lyap_spec is not None: all_spec.append(lyap_spec.cpu()) print( f"[{end}/{n}] det={det_exact.float().mean().item():.4f} " f"q_sel={_take_by_idx(exact, q_halt.argmax(1)).float().mean().item():.4f} " f"pass@K={exact.any(1).float().mean().item():.4f} steps={det_steps}", flush=True, ) det_exact = torch.cat(all_det_exact) det_token = torch.cat(all_det_token) exact = torch.cat(all_exact) token_acc = torch.cat(all_token) q_halt = torch.cat(all_q) q_continue = torch.cat(all_q_continue) lyap = torch.cat(all_lam) if all_lam else None lyap_spec = torch.cat(all_spec) if all_spec else None summary = summarize_selectors(exact, token_acc, q_halt, lyap, lyap_spec) summary["deterministic/exact"] = det_exact.float().mean().item() summary["deterministic/token_acc"] = det_token.mean().item() correct_counts = exact.float().sum(dim=1) oracle_success = exact.any(dim=1) q_selected = exact[torch.arange(exact.shape[0]), q_halt.argmax(dim=1)] det_success = det_exact.bool() det_fail = ~det_success if det_success.any().item(): summary["correct_count/det_success_mean"] = correct_counts[det_success].mean().item() summary["oracle_pass/det_success_frac"] = oracle_success[det_success].float().mean().item() summary["q_max/det_success_frac"] = q_selected[det_success].float().mean().item() if det_fail.any().item(): summary["correct_count/det_fail_mean"] = correct_counts[det_fail].mean().item() summary["oracle_pass/det_fail_frac"] = oracle_success[det_fail].float().mean().item() summary["q_max/det_fail_frac"] = q_selected[det_fail].float().mean().item() summary["n_samples"] = float(n) summary["rollouts"] = float(args.rollouts) summary["noise_std"] = float(args.noise_std) summary["include_clean"] = float(args.include_clean) summary["fd_lyap"] = float(args.fd_lyap) summary["fd_spectrum_k"] = float(args.fd_spectrum_k) summary["steps"] = float(args.steps) summary["perturb_l"] = float(args.perturb == "l") summary["perturb_h"] = float(args.perturb == "h") summary["perturb_both"] = float(args.perturb == "both") out_prefix = Path(args.out_prefix) out_prefix.parent.mkdir(parents=True, exist_ok=True) meta = { "ckpt_root": str(ckpt_root), "ckpt_name": args.ckpt_name, "n_samples": n, "batch_size": args.batch_size, "rollouts": args.rollouts, "steps": args.steps, "noise_std": args.noise_std, "include_clean": args.include_clean, "perturb": args.perturb, "fd_lyap": args.fd_lyap, "fd_spectrum_k": args.fd_spectrum_k, "fd_eps": args.fd_eps, "seed": args.seed, } np.savez_compressed( f"{out_prefix}.npz", idx=samples["idx"], det_exact=det_exact.numpy(), det_token_acc=det_token.numpy(), exact=exact.numpy(), token_acc=token_acc.numpy(), q_halt=q_halt.numpy(), q_continue=q_continue.numpy(), lyap=np.asarray([]) if lyap is None else lyap.numpy(), lyap_spec=np.asarray([]) if lyap_spec is None else lyap_spec.numpy(), meta_json=np.asarray(json.dumps(meta, sort_keys=True)), ) with open(f"{out_prefix}.meta.json", "w") as f: json.dump(meta, f, indent=2, sort_keys=True) with open(f"{out_prefix}.summary.csv", "w", newline="") as f: writer = csv.DictWriter(f, fieldnames=sorted(summary)) writer.writeheader() writer.writerow(summary) print("\nsummary") for key in sorted(summary): print(f"{key}: {summary[key]}") print(f"\nsaved {out_prefix}.npz and {out_prefix}.summary.csv") if __name__ == "__main__": main()