summary | refs | log | tree | commit | diff
path: root/research/flossing/ptrm_rollout_selection.py
diff options
context:
space:
mode:
author: YurenHao0426 <blackhao0426@gmail.com>, 2026-06-13 12:35:36 -0500
committer: YurenHao0426 <blackhao0426@gmail.com>, 2026-06-13 12:35:36 -0500
commit: 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch)
tree: c29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/ptrm_rollout_selection.py
rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipeline [HEAD, main]
Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/ptrm_rollout_selection.py')
-rw-r--r-- research/flossing/ptrm_rollout_selection.py | 643
1 files changed, 643 insertions, 0 deletions
diff --git a/research/flossing/ptrm_rollout_selection.py b/research/flossing/ptrm_rollout_selection.py
new file mode 100644
index 0000000..4fb1dd9
--- /dev/null
+++ b/research/flossing/ptrm_rollout_selection.py
@@ -0,0 +1,643 @@
+"""PTRM-style stochastic rollout evaluation with Q and stability selection.
+
+This is an inference-time experiment: no training, no weight updates.
+
+For each input, run K stochastic recursive trajectories by injecting Gaussian
+noise into the latent state before every ACT step. Select a trajectory by:
+ - Q head score (PTRM)
+ - finite-difference top Lyapunov proxy (lowest lambda)
+ - finite-difference low-rank Lyapunov spectrum proxies
+ - simple Q/lambda hybrid scores
+
+The Lyapunov proxy is computed by pairing each rollout with a tiny shadow
+trajectory that receives the same stochastic noise and is renormalized after
+each ACT step. This is much cheaper than JVP-based exact spectrum estimation
+and is enough to test whether stability can act as a free selector.
+
+The optional spectrum proxy generalizes the shadow trajectory to k orthogonal
+shadows and uses QR re-orthogonalization after every ACT step. This estimates
+the top-k finite-time spectrum in a random tangent subspace. It is much more
+expensive than top-1 because the model batch is multiplied by k + 1.
+"""
+from __future__ import annotations
+
import argparse
import csv
import json
import math
import os
import sys
from dataclasses import replace
from pathlib import Path
from typing import Any

import numpy as np
import torch
+
# Directory that is put on sys.path so `models.recursive_reasoning.trm` can be
# imported. The default is the path this script was written against; set the
# TRM_DIR environment variable to point at a different checkout without
# editing the file.
TRM_DIR = Path(os.environ.get("TRM_DIR", "/home/yurenh2/rrm/trm"))
# Prepend so this checkout wins over any other `models` package on the path.
sys.path.insert(0, str(TRM_DIR))
+
+from models.recursive_reasoning.trm import ( # noqa: E402
+ TinyRecursiveReasoningModel_ACTV1,
+ TinyRecursiveReasoningModel_ACTV1InnerCarry,
+)
+
+
+IGNORE_LABEL_ID = -100
+
+
def load_model(ckpt_root: Path, ckpt_name: str, device: str):
    """Build the model described by a checkpoint directory and load its weights.

    Args:
        ckpt_root: Directory holding ``all_config.yaml`` and the checkpoint file.
        ckpt_name: File name of the checkpoint inside ``ckpt_root``.
        device: Device the model is moved to after the weights are loaded.

    Returns:
        ``(model, cfg)`` where ``model`` is in eval mode on ``device`` and
        ``cfg`` is the parsed contents of ``all_config.yaml``.
    """
    import yaml  # PyYAML; only this loader needs it

    # Round-trip through JSON so the config contains only plain JSON types.
    raw_cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text())
    cfg = json.loads(json.dumps(raw_cfg))
    train_meta = json.loads((Path(cfg["data_paths"][0]) / "train" / "dataset.json").read_text())

    # Architecture config is completed with dataset-dependent sizes.
    arch_cfg = dict(cfg["arch"])
    arch_cfg.update(
        batch_size=cfg["global_batch_size"],
        seq_len=train_meta["seq_len"],
        vocab_size=train_meta["vocab_size"],
        num_puzzle_identifiers=train_meta["num_puzzle_identifiers"],
    )

    model = TinyRecursiveReasoningModel_ACTV1(arch_cfg)
    state = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True)

    def normalize_key(key: str) -> str:
        # "_orig_mod." is removed wherever it occurs (presumably the
        # torch.compile wrapper marker -- confirm against the training code).
        key = key.replace("_orig_mod.", "")
        # "model." is stripped only as a leading prefix: a blanket replace
        # would also rewrite names that merely contain "model." in the middle
        # (e.g. "inner.reward_model.w" -> "inner.reward_w").
        while key.startswith("model."):
            key = key[len("model."):]
        return key

    stripped = {normalize_key(k): v for k, v in state.items()}
    # strict=False: mismatches are reported below instead of raising.
    missing, unexpected = model.load_state_dict(stripped, strict=False)
    print(f"[load] {ckpt_root.name}/{ckpt_name} missing={len(missing)} unexpected={len(unexpected)}")
    if missing[:3]:
        print(f"[load] sample missing: {missing[:3]}")
    if unexpected[:3]:
        print(f"[load] sample unexpected: {unexpected[:3]}")
    model.to(device).eval()
    return model, cfg
+
+
def load_test_samples(data_path: Path, n_samples: int, seed: int):
    """Pick a seeded random subset of the test split, without replacement.

    Args:
        data_path: Dataset root; arrays are read from its ``test`` subdirectory.
        n_samples: Requested subset size (capped at the number of test rows).
        seed: Seed for the NumPy generator that draws the subset.

    Returns:
        Dict with int32 tensors ``inputs``, ``labels``, ``puzzle_identifiers``
        and the NumPy array ``idx`` of the selected row indices.
    """
    split_dir = data_path / "test"
    picker = np.random.default_rng(seed)

    raw_inputs = np.load(split_dir / "all__inputs.npy")
    raw_labels = np.load(split_dir / "all__labels.npy")
    raw_puzzle_ids = np.load(split_dir / "all__puzzle_identifiers.npy")

    available = len(raw_inputs)
    idx = picker.choice(available, size=min(n_samples, available), replace=False)

    def as_int32_tensor(array):
        return torch.from_numpy(array[idx].astype(np.int32))

    selected = {}
    selected["inputs"] = as_int32_tensor(raw_inputs)
    selected["labels"] = as_int32_tensor(raw_labels)
    selected["puzzle_identifiers"] = as_int32_tensor(raw_puzzle_ids)
    selected["idx"] = idx
    return selected
+
+
def batch_slice(samples: dict[str, Any], start: int, end: int, device: str):
    """Return rows ``start:end`` of the model-input tensors, moved to ``device``.

    Only ``inputs``, ``labels`` and ``puzzle_identifiers`` are kept; any other
    entry of ``samples`` (such as the index array) is dropped.
    """
    wanted = ("inputs", "labels", "puzzle_identifiers")
    sliced = {}
    for name, value in samples.items():
        if name not in wanted:
            continue
        sliced[name] = value[start:end].to(device, non_blocking=True)
    return sliced
+
+
def repeat_batch(batch: dict[str, torch.Tensor], repeats: int):
    """Repeat every row of every tensor ``repeats`` times, keeping copies adjacent.

    With ``repeats == 1`` the input dict itself is returned (no copy).
    """
    if repeats == 1:
        return batch
    tiled = {}
    for name, tensor in batch.items():
        tiled[name] = tensor.repeat_interleave(repeats, dim=0)
    return tiled
+
+
def cat_batches(a: dict[str, torch.Tensor], b: dict[str, torch.Tensor]):
    """Concatenate two batches along dim 0, key by key (keys taken from ``a``)."""
    joined = {}
    for name in a:
        joined[name] = torch.cat((a[name], b[name]), dim=0)
    return joined
+
+
def _rand_unit_like(inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, generator: torch.Generator):
    """Sample one random direction per row with unit norm over (z_H, z_L) jointly.

    Gaussian noise is drawn in float32 for ``z_H`` first, then ``z_L``; each row
    is divided by the norm of its concatenated (h, l) noise, and the two parts
    are cast to the dtypes of ``inner.z_H`` and ``inner.z_L``.
    """
    z_h, z_l = inner.z_H, inner.z_L
    raw_h = torch.randn(z_h.shape, device=z_h.device, dtype=torch.float32, generator=generator)
    raw_l = torch.randn(z_l.shape, device=z_l.device, dtype=torch.float32, generator=generator)

    sq_norm = raw_h.flatten(1).square().sum(-1) + raw_l.flatten(1).square().sum(-1)
    joint_norm = torch.sqrt(sq_norm).clamp_min(1e-30)

    def per_row(like: torch.Tensor) -> torch.Tensor:
        # Reshape the per-row norm so it broadcasts over the trailing dims.
        return joint_norm.view((like.shape[0],) + (1,) * (like.ndim - 1))

    unit_h = (raw_h / per_row(raw_h)).to(z_h.dtype)
    unit_l = (raw_l / per_row(raw_l)).to(z_l.dtype)
    return unit_h, unit_l
+
+
def _q_to_dirs(
    q: torch.Tensor,
    z_h_shape: torch.Size,
    z_l_shape: torch.Size,
    h_dtype: torch.dtype,
    l_dtype: torch.dtype,
):
    """Unpack a ``(total, dim, k)`` column matrix into per-direction h/l tensors.

    Each column is split after ``prod(z_h_shape)`` entries; the first part is
    reshaped to ``(total, k, *z_h_shape)`` and the rest to
    ``(total, k, *z_l_shape)``, then cast to ``h_dtype`` / ``l_dtype``.
    """
    n_rows, _, n_dirs = q.shape
    split_at = math.prod(z_h_shape)
    # One flattened direction per row of the last axis: (total, k, dim).
    rows = q.transpose(1, 2).contiguous()

    def unflatten(chunk: torch.Tensor, shape, dtype: torch.dtype) -> torch.Tensor:
        return chunk.reshape((n_rows, n_dirs) + tuple(shape)).to(dtype)

    h_part = unflatten(rows[:, :, :split_at], z_h_shape, h_dtype)
    l_part = unflatten(rows[:, :, split_at:], z_l_shape, l_dtype)
    return h_part, l_part
+
+
def _dirs_to_q(h_dirs: torch.Tensor, l_dirs: torch.Tensor):
    """Pack per-direction h/l tensors into a float32 ``(total, dim, k)`` matrix.

    Everything after the first two axes is flattened, h and l parts are joined
    along that flattened axis, and directions end up as columns.
    """
    flat_h = h_dirs.float().flatten(2)
    flat_l = l_dirs.float().flatten(2)
    rows = torch.cat((flat_h, flat_l), dim=2)
    return rows.transpose(1, 2).contiguous()
+
+
def _rand_orthonormal_dirs_like(
    inner: TinyRecursiveReasoningModel_ACTV1InnerCarry,
    spec_k: int,
    generator: torch.Generator,
):
    """Sample ``spec_k`` random directions per row, orthonormalised by QR.

    Gaussian noise is drawn in float32 (h part first, then l part), packed into
    a column matrix, reduced-QR factorised, and the Q factor is unpacked back
    into h/l direction tensors in the dtypes of ``inner.z_H`` / ``inner.z_L``.
    """
    n_rows = inner.z_H.shape[0]

    def gaussian_like(ref: torch.Tensor) -> torch.Tensor:
        return torch.randn(
            (n_rows, spec_k) + tuple(ref.shape[1:]),
            device=ref.device,
            dtype=torch.float32,
            generator=generator,
        )

    raw_h = gaussian_like(inner.z_H)
    raw_l = gaussian_like(inner.z_L)
    basis, _ = torch.linalg.qr(_dirs_to_q(raw_h, raw_l), mode="reduced")
    return _q_to_dirs(
        basis,
        inner.z_H.shape[1:],
        inner.z_L.shape[1:],
        inner.z_H.dtype,
        inner.z_L.dtype,
    )
+
+
def _make_spectrum_shadows(
    main: TinyRecursiveReasoningModel_ACTV1InnerCarry,
    h_dirs: torch.Tensor,
    l_dirs: torch.Tensor,
    eps: float,
):
    """Build the shadow carry ``main + eps * direction`` for every direction.

    ``h_dirs`` / ``l_dirs`` carry a direction axis at dim 1; the result folds
    that axis into the batch axis, so rows are ordered main-row-major with the
    directions of one main row adjacent. Outputs are detached.
    """
    n_main, n_dirs = h_dirs.shape[:2]

    def displaced(base: torch.Tensor, dirs: torch.Tensor) -> torch.Tensor:
        shifted = base[:, None] + eps * dirs.to(base.dtype)
        flat_shape = (n_main * n_dirs,) + tuple(base.shape[1:])
        return shifted.reshape(flat_shape).detach()

    return TinyRecursiveReasoningModel_ACTV1InnerCarry(
        z_H=displaced(main.z_H, h_dirs),
        z_L=displaced(main.z_L, l_dirs),
    )
+
+
def _repeat_inner_batch(batch: dict[str, torch.Tensor], repeats: int):
    """Repeat every row of every tensor ``repeats`` times (copies adjacent).

    Always builds a new dict, including when ``repeats == 1``.
    """
    tiled = {}
    for name, tensor in batch.items():
        tiled[name] = tensor.repeat_interleave(repeats, dim=0)
    return tiled
+
+
def _cat_many_batches(batches: list[dict[str, torch.Tensor]]):
    """Concatenate a list of batches along dim 0, using the first batch's keys."""
    reference = batches[0]
    joined = {}
    for name in reference:
        joined[name] = torch.cat([part[name] for part in batches], dim=0)
    return joined
+
+
def _split_inner(inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, n: int):
    """Split a carry into its first ``n`` rows and the remaining rows.

    Returns a pair of carries whose tensors are contiguous.
    """

    def take(rows: slice):
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=inner.z_H[rows].contiguous(),
            z_L=inner.z_L[rows].contiguous(),
        )

    head = take(slice(None, n))
    tail = take(slice(n, None))
    return head, tail
+
+
def _split_spectrum_inner(inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, n_main: int):
    """Split a combined carry into the first ``n_main`` rows and the shadow rows.

    Same row split as ``_split_inner``; kept as a separate name for the
    spectrum code path.
    """
    return _split_inner(inner, n_main)
+
+
def _cat_inner(a: TinyRecursiveReasoningModel_ACTV1InnerCarry, b: TinyRecursiveReasoningModel_ACTV1InnerCarry):
    """Stack two carries along the batch axis (rows of ``a`` first)."""
    stacked_h = torch.cat((a.z_H, b.z_H), dim=0)
    stacked_l = torch.cat((a.z_L, b.z_L), dim=0)
    return TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=stacked_h, z_L=stacked_l)
+
+
def _sample_noise(
    shape: torch.Size,
    std: float,
    generator: torch.Generator,
    dtype: torch.dtype,
    device: torch.device,
):
    """Draw zero-mean Gaussian noise with standard deviation ``std``.

    Sampling happens in float32 and the result is cast to ``dtype``. When
    ``std <= 0`` a zero tensor is returned and ``generator`` is not advanced.
    """
    if std <= 0:
        return torch.zeros(shape, device=device, dtype=dtype)
    standard_normal = torch.randn(shape, device=device, dtype=torch.float32, generator=generator)
    scaled = std * standard_normal
    return scaled.to(dtype)
+
+
def _apply_step_noise(
    inner: TinyRecursiveReasoningModel_ACTV1InnerCarry,
    noise_h: torch.Tensor,
    noise_l: torch.Tensor,
    perturb: str,
):
    """Return a copy of ``inner`` with noise added to the selected latents.

    ``perturb`` chooses the target: ``"h"`` adds ``noise_h`` to ``z_H``,
    ``"l"`` adds ``noise_l`` to ``z_L``, ``"both"`` does both. Any other value
    leaves both latents unchanged.
    """
    touch_h = perturb in ("h", "both")
    touch_l = perturb in ("l", "both")
    new_h = inner.z_H + noise_h if touch_h else inner.z_H
    new_l = inner.z_L + noise_l if touch_l else inner.z_L
    return replace(inner, z_H=new_h, z_L=new_l)
+
+
def _separation(
    main: TinyRecursiveReasoningModel_ACTV1InnerCarry,
    shadow: TinyRecursiveReasoningModel_ACTV1InnerCarry,
):
    """Per-row Euclidean distance between two carries over (z_H, z_L) jointly.

    Differences are taken in float32; the result is clamped below at 1e-30 so
    it can be divided by or passed to ``log``.
    """
    gap_h = (shadow.z_H.float() - main.z_H.float()).flatten(1)
    gap_l = (shadow.z_L.float() - main.z_L.float()).flatten(1)
    squared = gap_h.square().sum(-1) + gap_l.square().sum(-1)
    distance = torch.sqrt(squared)
    return distance.clamp_min(1e-30)
+
+
def _renormalize_shadow(
    main: TinyRecursiveReasoningModel_ACTV1InnerCarry,
    shadow: TinyRecursiveReasoningModel_ACTV1InnerCarry,
    eps: float,
):
    """Pull each shadow row back to distance ``eps`` from its main row.

    The direction ``shadow - main`` is kept and only its joint (z_H, z_L)
    length is rescaled. The returned carry is detached.
    """
    ratio = eps / _separation(main, shadow)

    def rescaled(anchor: torch.Tensor, moved: torch.Tensor) -> torch.Tensor:
        # Broadcast the per-row ratio over the trailing dims, in the latent dtype.
        factor = ratio.view((ratio.shape[0],) + (1,) * (anchor.ndim - 1)).to(anchor.dtype)
        return (anchor + (moved - anchor) * factor).detach()

    return TinyRecursiveReasoningModel_ACTV1InnerCarry(
        z_H=rescaled(main.z_H, shadow.z_H),
        z_L=rescaled(main.z_L, shadow.z_L),
    )
+
+
@torch.inference_mode()
def deterministic_eval(model, batch: dict[str, torch.Tensor]):
    """Run the model without injected noise until every sequence has halted.

    Returns:
        ``(exact, token_acc, q_halt, steps)`` computed from the outputs of the
        last model call; ``steps`` is the number of calls made.
    """
    with torch.device(batch["inputs"].device):
        carry = model.initial_carry(batch)

    n_calls = 0
    everything_halted = False
    while not everything_halted:
        carry, outputs = model(carry=carry, batch=batch)
        n_calls += 1
        everything_halted = bool(carry.halted.all())

    exact, token_acc = correctness(outputs["logits"], batch["labels"])
    return exact, token_acc, outputs["q_halt_logits"], n_calls
+
+
def correctness(logits: torch.Tensor, labels: torch.Tensor):
    """Score argmax predictions against labels, skipping ignored positions.

    Returns:
        ``exact``: per-row bool, True when every non-ignored position matches.
        ``token_acc``: per-row fraction of non-ignored positions that match
        (the denominator is at least 1, so a fully ignored row scores 0).
    """
    predicted = logits.argmax(dim=-1)
    scored = labels != IGNORE_LABEL_ID
    hits = (predicted == labels) & scored
    # An ignored position can never break an exact match.
    exact = (hits | ~scored).all(dim=-1)
    n_scored = scored.sum(-1).clamp_min(1)
    token_acc = hits.sum(-1).float() / n_scored.float()
    return exact, token_acc
+
+
@torch.inference_mode()
def ptrm_rollouts(
    model,
    batch: dict[str, torch.Tensor],
    rollouts: int,
    steps: int,
    noise_std: float,
    include_clean: bool,
    perturb: str,
    fd_lyap: bool,
    fd_spectrum_k: int,
    fd_eps: float,
    generator: torch.Generator,
):
    """Run ``rollouts`` noisy trajectories per input and score each of them.

    Every input row is repeated ``rollouts`` times (copies adjacent). Before
    each of the ``steps`` inner-model calls, Gaussian noise is added to the
    latents chosen by ``perturb``. Optionally, finite-difference shadows are
    evolved next to each rollout to accumulate log growth rates.

    Args:
        model: Model exposing ``initial_carry``, ``inner`` and
            ``inner.reset_carry``.
        batch: Input batch of tensors; must contain ``inputs`` and ``labels``.
        rollouts: Trajectories per input (>= 1).
        steps: Number of inner-model calls per trajectory (>= 1).
        noise_std: Standard deviation of the injected noise (<= 0 disables it).
        include_clean: If true and ``rollouts > 1``, rollout 0 of every input
            receives zero noise.
        perturb: ``"h"``, ``"l"`` or ``"both"``.
        fd_lyap: Track a single shadow per rollout (top growth-rate proxy).
            Ignored when ``fd_spectrum_k > 0``.
        fd_spectrum_k: If > 0, track this many QR-re-orthogonalised shadows
            per rollout instead.
        fd_eps: Shadow displacement; must be > 0 when shadows are tracked.
        generator: Random generator used for all noise and directions.

    Returns:
        ``(exact, token_acc, q_halt, q_continue, lyap, lyap_spec)``. The first
        four are shaped ``(batch, rollouts)``. ``lyap`` is ``(batch, rollouts)``
        or ``None``; ``lyap_spec`` is ``(batch, rollouts, fd_spectrum_k)``
        sorted in descending order, or ``None``.

    Raises:
        ValueError: On ``rollouts < 1``, ``steps < 1``, or a non-positive
            ``fd_eps`` while shadows are requested.
    """
    use_spectrum = fd_spectrum_k > 0
    use_top1 = fd_lyap and not use_spectrum

    # Explicit checks instead of assert: asserts vanish under ``python -O``.
    if rollouts < 1:
        raise ValueError(f"rollouts must be >= 1, got {rollouts}")
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    if (use_spectrum or use_top1) and fd_eps <= 0:
        raise ValueError(f"fd_eps must be > 0 when tracking shadows, got {fd_eps}")

    device = batch["inputs"].device
    base_batch_size = batch["inputs"].shape[0]
    expanded = repeat_batch(batch, rollouts)
    total = expanded["inputs"].shape[0]

    with torch.device(device):
        carry = model.initial_carry(expanded)
    reset = torch.ones_like(carry.halted)
    main = model.inner.reset_carry(reset, carry.inner_carry)

    # Mask selecting rollout 0 of every input; it does not change between
    # steps, so it is built once here.
    clean_mask = None
    if include_clean and rollouts > 1:
        rollout_id = torch.arange(total, device=device) % rollouts
        clean_mask = (rollout_id == 0).view((-1,) + (1,) * (main.z_H.ndim - 1))

    shadow = None
    lyap_sum = None
    spec_shadows = None
    lyap_spec_sum = None
    combined_batch = None
    if use_spectrum:
        spec_h_dirs, spec_l_dirs = _rand_orthonormal_dirs_like(main, fd_spectrum_k, generator)
        spec_shadows = _make_spectrum_shadows(main, spec_h_dirs, spec_l_dirs, fd_eps)
        lyap_spec_sum = torch.zeros(total, fd_spectrum_k, device=device, dtype=torch.float32)
        # Main rows first, then fd_spectrum_k shadow rows per main row. The
        # inputs are the same at every step, so concatenate them only once.
        # NOTE(review): this reuse assumes model.inner does not modify the
        # batch tensors in place -- the no-shadow path already reuses
        # `expanded` the same way.
        combined_batch = _cat_many_batches([expanded, _repeat_inner_batch(expanded, fd_spectrum_k)])
    elif use_top1:
        dh, dl = _rand_unit_like(main, generator)
        shadow = TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=(main.z_H + fd_eps * dh).detach(),
            z_L=(main.z_L + fd_eps * dl).detach(),
        )
        lyap_sum = torch.zeros(total, device=device, dtype=torch.float32)
        combined_batch = cat_batches(expanded, expanded)

    logits = None
    q_halt = None
    q_continue = None
    for _ in range(steps):
        noise_h = _sample_noise(main.z_H.shape, noise_std, generator, main.z_H.dtype, device)
        noise_l = _sample_noise(main.z_L.shape, noise_std, generator, main.z_L.dtype, device)
        if clean_mask is not None:
            noise_h = torch.where(clean_mask, torch.zeros_like(noise_h), noise_h)
            noise_l = torch.where(clean_mask, torch.zeros_like(noise_l), noise_l)

        main = _apply_step_noise(main, noise_h, noise_l, perturb)
        if use_spectrum:
            # Shadows receive exactly the noise of the rollout they follow.
            shadow_noise_h = noise_h.repeat_interleave(fd_spectrum_k, dim=0)
            shadow_noise_l = noise_l.repeat_interleave(fd_spectrum_k, dim=0)
            spec_shadows = _apply_step_noise(spec_shadows, shadow_noise_h, shadow_noise_l, perturb)
            combined_inner = _cat_inner(main, spec_shadows)
            combined_inner, combined_logits, (combined_q_halt, combined_q_continue) = model.inner(
                combined_inner, combined_batch
            )
            main, spec_shadows = _split_spectrum_inner(combined_inner, total)
            logits = combined_logits[:total]
            q_halt = combined_q_halt[:total]
            q_continue = combined_q_continue[:total]

            # Finite-difference images of the k directions, in float32.
            delta_h = (
                spec_shadows.z_H.reshape((total, fd_spectrum_k) + tuple(main.z_H.shape[1:])).float()
                - main.z_H[:, None].float()
            ) / fd_eps
            delta_l = (
                spec_shadows.z_L.reshape((total, fd_spectrum_k) + tuple(main.z_L.shape[1:])).float()
                - main.z_L[:, None].float()
            ) / fd_eps
            # |diag(R)| gives the per-direction growth of this step; Q becomes
            # the re-orthogonalised direction set for the next step.
            q, r = torch.linalg.qr(_dirs_to_q(delta_h, delta_l), mode="reduced")
            diag = torch.diagonal(r, dim1=-2, dim2=-1).abs().clamp_min(1e-30)
            lyap_spec_sum = lyap_spec_sum + torch.log(diag).float()
            spec_h_dirs, spec_l_dirs = _q_to_dirs(
                q, main.z_H.shape[1:], main.z_L.shape[1:], main.z_H.dtype, main.z_L.dtype
            )
            spec_shadows = _make_spectrum_shadows(main, spec_h_dirs, spec_l_dirs, fd_eps)
        elif use_top1:
            shadow = _apply_step_noise(shadow, noise_h, noise_l, perturb)
            combined_inner = _cat_inner(main, shadow)
            combined_inner, combined_logits, (combined_q_halt, combined_q_continue) = model.inner(
                combined_inner, combined_batch
            )
            main, shadow = _split_inner(combined_inner, total)
            logits = combined_logits[:total]
            q_halt = combined_q_halt[:total]
            q_continue = combined_q_continue[:total]
            # Log growth of the separation over this step, then reset it to eps.
            sep = _separation(main, shadow)
            lyap_sum = lyap_sum + torch.log(sep / fd_eps).float()
            shadow = _renormalize_shadow(main, shadow, fd_eps)
        else:
            main, logits, (q_halt, q_continue) = model.inner(main, expanded)

    # steps >= 1 was checked above, so logits / q_halt are set here.
    exact, token_acc = correctness(logits, expanded["labels"])
    exact = exact.view(base_batch_size, rollouts)
    token_acc = token_acc.view(base_batch_size, rollouts)
    q_halt = q_halt.float().view(base_batch_size, rollouts)
    if q_continue is not None:
        q_continue = q_continue.float().view(base_batch_size, rollouts)
    else:
        q_continue = torch.zeros_like(q_halt)

    lyap = None
    lyap_spec = None
    if use_spectrum:
        # Average the per-step logs, sort each spectrum, largest first.
        lyap_spec = (lyap_spec_sum / steps).view(base_batch_size, rollouts, fd_spectrum_k)
        lyap_spec = torch.sort(lyap_spec, dim=-1, descending=True).values
        lyap = lyap_spec[..., 0]
    elif use_top1:
        lyap = (lyap_sum / steps).view(base_batch_size, rollouts)
    return exact, token_acc, q_halt, q_continue, lyap, lyap_spec
+
+
def _take_by_idx(values: torch.Tensor, idx: torch.Tensor):
    """For each row ``i`` of a 2-D tensor, return ``values[i, idx[i]]``."""
    column = idx.unsqueeze(1)
    picked = torch.gather(values, 1, column)
    return picked.squeeze(1)
+
+
def _zscore_per_row(values: torch.Tensor):
    """Standardise each row of a 2-D tensor to zero mean and unit std.

    The std is floored at 1e-6, so a constant row maps to zeros. A tensor with
    a single column also maps to zeros: the sample std of one element is
    undefined and would otherwise turn the whole result into NaN.
    """
    centered = values - values.mean(dim=1, keepdim=True)
    if values.shape[1] < 2:
        return torch.zeros_like(centered)
    spread = values.std(dim=1, keepdim=True).clamp_min(1e-6)
    return centered / spread
+
+
def summarize_selectors(exact, token_acc, q_halt, lyap, lyap_spec=None):
    """Aggregate rollout results into a flat dict of scalar statistics.

    Args:
        exact: Bool tensor ``(n, rollouts)``, exact-match flag per rollout.
        token_acc: Float tensor ``(n, rollouts)``, token accuracy per rollout.
        q_halt: Float tensor ``(n, rollouts)``, halting score per rollout.
        lyap: Optional ``(n, rollouts)`` growth-rate proxy.
        lyap_spec: Optional ``(n, rollouts, k)`` spectrum proxy.

    Returns:
        Dict mapping metric names to Python floats: accuracy of each selection
        rule, distribution of the per-input number of correct rollouts, and
        mean Q / growth-rate / spectrum statistics (overall, and split by
        success vs. failure when both outcomes occur).
    """
    n_rows, n_rollouts = exact.shape
    row_ids = torch.arange(n_rows, device=exact.device)
    hits_per_row = exact.float().sum(dim=1)

    # Per-rollout spectrum statistics, shared by the selectors and the means.
    spec_pos_mass = spec_pos_l2 = spec_mean = spec_count_pos = spec_spread = None
    if lyap_spec is not None:
        positive_part = lyap_spec.clamp_min(0)
        spec_pos_mass = positive_part.sum(dim=-1)
        spec_pos_l2 = positive_part.square().mean(dim=-1).sqrt()
        spec_mean = lyap_spec.mean(dim=-1)
        spec_count_pos = (lyap_spec > 0).float().sum(dim=-1)
        spec_spread = lyap_spec[..., 0] - lyap_spec[..., -1]

    # Selection rules: each maps to one chosen rollout index per row.
    # ``None`` marks the oracle, which takes the best rollout of every row.
    picks = {
        "rollout0": torch.zeros(n_rows, device=exact.device, dtype=torch.long),
        "q_max": q_halt.argmax(dim=1),
        "oracle_pass": None,
    }
    if lyap is not None:
        picks["lyap_min"] = lyap.argmin(dim=1)
        q_z = _zscore_per_row(q_halt)
        lam_z = _zscore_per_row(lyap)
        for alpha in (0.25, 0.5, 1.0, 2.0):
            picks[f"q_minus_{alpha:g}lambda"] = (q_z - alpha * lam_z).argmax(dim=1)
    if lyap_spec is not None:
        picks["spec_pos_mass_min"] = spec_pos_mass.argmin(dim=1)
        picks["spec_pos_l2_min"] = spec_pos_l2.argmin(dim=1)
        picks["spec_mean_min"] = spec_mean.argmin(dim=1)
        picks["spec_count_pos_min"] = spec_count_pos.argmin(dim=1)
        picks["spec_spread_min"] = spec_spread.argmin(dim=1)

    out: dict[str, float] = {}
    for name, chosen in picks.items():
        if chosen is None:
            picked_exact = exact.any(dim=1)
            picked_acc = token_acc.max(dim=1).values
        else:
            picked_exact = exact[row_ids, chosen]
            picked_acc = token_acc[row_ids, chosen]
        out[f"{name}/exact"] = picked_exact.float().mean().item()
        out[f"{name}/token_acc"] = picked_acc.mean().item()

    out["mean_rollout/exact"] = exact.float().mean().item()
    out["mean_rollout/token_acc"] = token_acc.mean().item()

    # Distribution of the number of correct rollouts per input.
    out["correct_count/mean"] = hits_per_row.mean().item()
    out["correct_count/std"] = hits_per_row.std(unbiased=False).item()
    out["correct_count/median"] = hits_per_row.median().item()
    for percent, fraction in ((10, 0.10), (25, 0.25), (75, 0.75), (90, 0.90)):
        out[f"correct_count/q{percent}"] = torch.quantile(hits_per_row, fraction).item()
    out["correct_count/zero_frac"] = (hits_per_row == 0).float().mean().item()
    out["correct_count/full_frac"] = (hits_per_row == n_rollouts).float().mean().item()
    for threshold in (1, 5, 10, 25, 50, 75, 90):
        if threshold > n_rollouts:
            continue
        out[f"correct_count/ge_{threshold}_frac"] = (hits_per_row >= threshold).float().mean().item()

    # Success/failure splits are reported only when both outcomes are present.
    failed = ~exact
    both_outcomes = bool(exact.any().item() and failed.any().item())

    out["q_mean"] = q_halt.mean().item()
    if lyap is not None:
        out["lambda_mean"] = lyap.mean().item()
        if both_outcomes:
            out["lambda_success_mean"] = lyap[exact].mean().item()
            out["lambda_fail_mean"] = lyap[failed].mean().item()
            out["q_success_mean"] = q_halt[exact].mean().item()
            out["q_fail_mean"] = q_halt[failed].mean().item()
    if lyap_spec is not None:
        out["spec_k"] = float(lyap_spec.shape[-1])
        out["spec_pos_mass_mean"] = spec_pos_mass.mean().item()
        out["spec_pos_l2_mean"] = spec_pos_l2.mean().item()
        out["spec_mean_mean"] = spec_mean.mean().item()
        out["spec_count_pos_mean"] = spec_count_pos.mean().item()
        out["spec_spread_mean"] = spec_spread.mean().item()
        if both_outcomes:
            out["spec_pos_mass_success_mean"] = spec_pos_mass[exact].mean().item()
            out["spec_pos_mass_fail_mean"] = spec_pos_mass[failed].mean().item()
            out["spec_mean_success_mean"] = spec_mean[exact].mean().item()
            out["spec_mean_fail_mean"] = spec_mean[failed].mean().item()
            out["spec_count_pos_success_mean"] = spec_count_pos[exact].mean().item()
            out["spec_count_pos_fail_mean"] = spec_count_pos[failed].mean().item()
    return out
+
+
def main():
    """Command-line entry point: evaluate rollout selectors on a test subset.

    Loads a checkpoint and a seeded random subset of the test split, runs the
    deterministic evaluation and the noisy rollouts batch by batch, then writes
    ``<out-prefix>.npz`` (per-sample arrays), ``<out-prefix>.meta.json`` (run
    settings) and ``<out-prefix>.summary.csv`` (one row of aggregate metrics).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt-root", required=True)
    parser.add_argument("--ckpt-name", default="step_260410")
    parser.add_argument("--n-samples", type=int, default=512)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--rollouts", type=int, default=8)
    parser.add_argument("--steps", type=int, default=16)
    parser.add_argument("--noise-std", type=float, default=1e-3)
    parser.add_argument("--include-clean", action="store_true")
    parser.add_argument("--perturb", choices=["h", "l", "both"], default="both")
    parser.add_argument("--fd-lyap", action="store_true")
    parser.add_argument("--fd-spectrum-k", type=int, default=0)
    parser.add_argument("--fd-eps", type=float, default=1e-2)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--out-prefix", default="research/flossing/ptrm_selection")
    args = parser.parse_args()

    # Reject values that would otherwise fail deep inside the evaluation loop.
    for flag, value in (
        ("--n-samples", args.n_samples),
        ("--batch-size", args.batch_size),
        ("--rollouts", args.rollouts),
        ("--steps", args.steps),
    ):
        if value < 1:
            parser.error(f"{flag} must be >= 1, got {value}")

    # Prefer the GPU, but do not fail outright on a machine without CUDA.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
        print("[warn] CUDA is not available; running on CPU", file=sys.stderr)

    torch.manual_seed(args.seed)
    # Separate generator for rollout noise, offset from the global seed.
    generator = torch.Generator(device=device).manual_seed(args.seed + 12345)

    ckpt_root = Path(args.ckpt_root)
    model, cfg = load_model(ckpt_root, args.ckpt_name, device)
    samples = load_test_samples(Path(cfg["data_paths"][0]), args.n_samples, args.seed)
    n = len(samples["inputs"])
    if n == 0:
        raise SystemExit("no test samples found; nothing to evaluate")

    all_det_exact, all_det_token = [], []
    all_exact, all_token, all_q, all_q_continue, all_lam, all_spec = [], [], [], [], [], []

    for start in range(0, n, args.batch_size):
        end = min(start + args.batch_size, n)
        batch = batch_slice(samples, start, end, device)
        det_exact, det_token, _det_q, det_steps = deterministic_eval(model, batch)
        exact, token_acc, q_halt, q_continue, lyap, lyap_spec = ptrm_rollouts(
            model=model,
            batch=batch,
            rollouts=args.rollouts,
            steps=args.steps,
            noise_std=args.noise_std,
            include_clean=args.include_clean,
            perturb=args.perturb,
            fd_lyap=args.fd_lyap,
            fd_spectrum_k=args.fd_spectrum_k,
            fd_eps=args.fd_eps,
            generator=generator,
        )
        all_det_exact.append(det_exact.cpu())
        all_det_token.append(det_token.cpu())
        all_exact.append(exact.cpu())
        all_token.append(token_acc.cpu())
        all_q.append(q_halt.cpu())
        all_q_continue.append(q_continue.cpu())
        if lyap is not None:
            all_lam.append(lyap.cpu())
        if lyap_spec is not None:
            all_spec.append(lyap_spec.cpu())
        print(
            f"[{end}/{n}] det={det_exact.float().mean().item():.4f} "
            f"q_sel={_take_by_idx(exact, q_halt.argmax(1)).float().mean().item():.4f} "
            f"pass@K={exact.any(1).float().mean().item():.4f} steps={det_steps}",
            flush=True,
        )

    det_exact = torch.cat(all_det_exact)
    det_token = torch.cat(all_det_token)
    exact = torch.cat(all_exact)
    token_acc = torch.cat(all_token)
    q_halt = torch.cat(all_q)
    q_continue = torch.cat(all_q_continue)
    lyap = torch.cat(all_lam) if all_lam else None
    lyap_spec = torch.cat(all_spec) if all_spec else None
    summary = summarize_selectors(exact, token_acc, q_halt, lyap, lyap_spec)
    summary["deterministic/exact"] = det_exact.float().mean().item()
    summary["deterministic/token_acc"] = det_token.mean().item()

    # Rollout statistics conditioned on whether the deterministic run succeeded.
    correct_counts = exact.float().sum(dim=1)
    oracle_success = exact.any(dim=1)
    q_selected = exact[torch.arange(exact.shape[0]), q_halt.argmax(dim=1)]
    det_success = det_exact.bool()
    det_fail = ~det_success
    if det_success.any().item():
        summary["correct_count/det_success_mean"] = correct_counts[det_success].mean().item()
        summary["oracle_pass/det_success_frac"] = oracle_success[det_success].float().mean().item()
        summary["q_max/det_success_frac"] = q_selected[det_success].float().mean().item()
    if det_fail.any().item():
        summary["correct_count/det_fail_mean"] = correct_counts[det_fail].mean().item()
        summary["oracle_pass/det_fail_frac"] = oracle_success[det_fail].float().mean().item()
        summary["q_max/det_fail_frac"] = q_selected[det_fail].float().mean().item()

    # Run settings are stored as floats so the CSV row stays purely numeric.
    summary["n_samples"] = float(n)
    summary["rollouts"] = float(args.rollouts)
    summary["noise_std"] = float(args.noise_std)
    summary["include_clean"] = float(args.include_clean)
    summary["fd_lyap"] = float(args.fd_lyap)
    summary["fd_spectrum_k"] = float(args.fd_spectrum_k)
    summary["steps"] = float(args.steps)
    summary["perturb_l"] = float(args.perturb == "l")
    summary["perturb_h"] = float(args.perturb == "h")
    summary["perturb_both"] = float(args.perturb == "both")

    out_prefix = Path(args.out_prefix)
    out_prefix.parent.mkdir(parents=True, exist_ok=True)
    meta = {
        "ckpt_root": str(ckpt_root),
        "ckpt_name": args.ckpt_name,
        "n_samples": n,
        "batch_size": args.batch_size,
        "rollouts": args.rollouts,
        "steps": args.steps,
        "noise_std": args.noise_std,
        "include_clean": args.include_clean,
        "perturb": args.perturb,
        "fd_lyap": args.fd_lyap,
        "fd_spectrum_k": args.fd_spectrum_k,
        "fd_eps": args.fd_eps,
        "seed": args.seed,
    }
    np.savez_compressed(
        f"{out_prefix}.npz",
        idx=samples["idx"],
        det_exact=det_exact.numpy(),
        det_token_acc=det_token.numpy(),
        exact=exact.numpy(),
        token_acc=token_acc.numpy(),
        q_halt=q_halt.numpy(),
        q_continue=q_continue.numpy(),
        # Empty arrays stand in for metrics that were not computed.
        lyap=np.asarray([]) if lyap is None else lyap.numpy(),
        lyap_spec=np.asarray([]) if lyap_spec is None else lyap_spec.numpy(),
        meta_json=np.asarray(json.dumps(meta, sort_keys=True)),
    )
    with open(f"{out_prefix}.meta.json", "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, sort_keys=True)
    with open(f"{out_prefix}.summary.csv", "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=sorted(summary))
        writer.writeheader()
        writer.writerow(summary)

    print("\nsummary")
    for key in sorted(summary):
        print(f"{key}: {summary[key]}")
    print(f"\nsaved {out_prefix}.npz and {out_prefix}.summary.csv")


if __name__ == "__main__":
    main()