"""Inference-time robustness to initial recurrent-state perturbations.

This tests whether trajectory-perturbation training enlarged the correct
attractor basin. Unlike PTRM-style rollout noise, the perturbation is applied
once after resetting z_H/z_L, matching the training augmentation mechanism.

For every sigma in ``--sigmas`` each sampled test problem is evaluated
``--rollouts`` times with independent noise draws. Three files are written,
each named ``<out-prefix>`` plus an appended extension:

* ``.summary.csv`` -- one row of aggregate statistics per sigma,
* ``.meta.json``   -- the command-line arguments and resolved data path,
* ``.npz``         -- per-problem / per-rollout exactness and token accuracy.
"""

from __future__ import annotations

import argparse
import csv
import json
import math
import sys
from dataclasses import replace
from pathlib import Path
from typing import Any

import numpy as np
import torch
import yaml

TRM_DIR = Path("/home/yurenh2/rrm/trm")
sys.path.insert(0, str(TRM_DIR))

from models.recursive_reasoning.trm import (  # noqa: E402
    TinyRecursiveReasoningModel_ACTV1,
    TinyRecursiveReasoningModel_ACTV1InnerCarry,
)

# Label value excluded from both the exact-match and token-accuracy metrics.
IGNORE_LABEL_ID = -100


def parse_float_list(text: str) -> list[float]:
    """Parse a comma-separated string of floats, skipping empty fields."""
    return [float(x.strip()) for x in text.split(",") if x.strip()]


def load_model(ckpt_root: Path, ckpt_name: str, device: str):
    """Build the model described by a checkpoint directory and load its weights.

    Args:
        ckpt_root: Directory holding ``all_config.yaml`` and the checkpoint file.
        ckpt_name: File name of the checkpoint inside ``ckpt_root``.
        device: Device the model is moved to after loading.

    Returns:
        ``(model, cfg, data_path)``: the model in eval mode, the parsed config
        dict, and the dataset root taken from ``data_path`` (or the first entry
        of ``data_paths``) in the config.
    """
    cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text())
    data_path = Path(cfg.get("data_path") or cfg["data_paths"][0])
    train_meta = json.loads((data_path / "train" / "dataset.json").read_text())

    # Architecture config plus the dataset-dependent sizes from the train split.
    arch_cfg = dict(cfg["arch"])
    arch_cfg.update(
        batch_size=cfg["global_batch_size"],
        seq_len=train_meta["seq_len"],
        vocab_size=train_meta["vocab_size"],
        num_puzzle_identifiers=train_meta["num_puzzle_identifiers"],
        causal=False,
    )
    model = TinyRecursiveReasoningModel_ACTV1(arch_cfg)

    state = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True)
    # NOTE(review): this removes "_orig_mod." and "model." wherever they occur
    # in a key, not only as prefixes -- presumably these are torch.compile /
    # wrapper prefixes; confirm no parameter name contains "model." internally.
    stripped = {k.replace("_orig_mod.", "").replace("model.", ""): v for k, v in state.items()}

    # strict=False: mismatches are reported below rather than raised, so check
    # the printed counts before trusting the results.
    missing, unexpected = model.load_state_dict(stripped, strict=False)
    print(
        f"[load] {ckpt_root.name}/{ckpt_name} missing={len(missing)} unexpected={len(unexpected)}",
        flush=True,
    )
    if missing:
        print(f"[load] sample missing: {missing[:4]}", flush=True)
    if unexpected:
        print(f"[load] sample unexpected: {unexpected[:4]}", flush=True)

    model.to(device).eval()
    return model, cfg, data_path


def load_test_samples(data_path: Path, n_samples: int, seed: int):
    """Draw a seeded random subset (without replacement) of the test split.

    Args:
        data_path: Dataset root containing ``test/all__*.npy``.
        n_samples: Maximum number of problems to draw; capped at the split size.
        seed: Seed for the NumPy generator that picks the subset.

    Returns:
        Dict with int32 CPU tensors ``inputs``, ``labels`` and
        ``puzzle_identifiers`` plus ``idx``, the NumPy array of selected row
        indices in drawn order.
    """
    rng = np.random.default_rng(seed)
    test_dir = data_path / "test"
    inputs = np.load(test_dir / "all__inputs.npy")
    labels = np.load(test_dir / "all__labels.npy")
    puzzle_ids = np.load(test_dir / "all__puzzle_identifiers.npy")

    n = min(n_samples, len(inputs))
    idx = rng.choice(len(inputs), size=n, replace=False)
    return {
        "inputs": torch.from_numpy(inputs[idx].astype(np.int32)),
        "labels": torch.from_numpy(labels[idx].astype(np.int32)),
        "puzzle_identifiers": torch.from_numpy(puzzle_ids[idx].astype(np.int32)),
        "idx": idx,
    }


def batch_slice(samples: dict[str, Any], start: int, end: int, device: str):
    """Return rows ``[start, end)`` of the tensor fields, moved to ``device``."""
    return {
        k: v[start:end].to(device, non_blocking=True)
        for k, v in samples.items()
        if k in ("inputs", "labels", "puzzle_identifiers")
    }


def repeat_batch(batch: dict[str, torch.Tensor], repeats: int):
    """Repeat every row ``repeats`` times, keeping copies of a row adjacent.

    With ``repeats == 1`` the input dict itself is returned (no copy).
    """
    if repeats == 1:
        return batch
    return {k: v.repeat_interleave(repeats, dim=0) for k, v in batch.items()}


def sample_unit_noise_like(tensor: torch.Tensor, generator: torch.Generator, distribution: str):
    """Sample zero-mean, unit-variance noise with ``tensor``'s shape, device and dtype.

    ``distribution == "uniform"`` draws from U(-sqrt(3), sqrt(3)), which has
    unit variance; any other value draws standard Gaussian noise. Sampling is
    done in float32 and cast to ``tensor.dtype`` afterwards.
    """
    if distribution == "uniform":
        noise = 2.0 * torch.rand(tensor.shape, device=tensor.device, dtype=torch.float32, generator=generator) - 1.0
        noise = noise * math.sqrt(3.0)
    else:
        # NOTE(review): unknown distribution names silently fall back to
        # Gaussian; main() restricts the choices through argparse.
        noise = torch.randn(tensor.shape, device=tensor.device, dtype=torch.float32, generator=generator)
    return noise.to(tensor.dtype)


def apply_initial_noise(
    inner: TinyRecursiveReasoningModel_ACTV1InnerCarry,
    sigma: float,
    perturb: str,
    generator: torch.Generator,
    distribution: str,
):
    """Return a copy of ``inner`` with scaled noise added to z_H and/or z_L.

    Args:
        inner: Inner carry whose ``z_H`` / ``z_L`` fields are perturbed.
        sigma: Noise scale; ``sigma <= 0`` returns ``inner`` unchanged and
            consumes no random numbers.
        perturb: ``"h"``, ``"l"`` or ``"both"``; selects the perturbed states.
        generator: Random generator used for the noise draws.
        distribution: Passed through to ``sample_unit_noise_like``.
    """
    if sigma <= 0:
        return inner
    z_h, z_l = inner.z_H, inner.z_L
    if perturb in ("h", "both"):
        z_h = z_h + sigma * sample_unit_noise_like(z_h, generator, distribution)
    if perturb in ("l", "both"):
        z_l = z_l + sigma * sample_unit_noise_like(z_l, generator, distribution)
    return replace(inner, z_H=z_h, z_L=z_l)


def correctness(logits: torch.Tensor, labels: torch.Tensor):
    """Compute per-sequence exact match and token accuracy from logits.

    Positions whose label equals ``IGNORE_LABEL_ID`` are excluded from both
    metrics.

    Returns:
        ``(exact, token_acc)``: a bool tensor that is True where every
        non-ignored position is predicted correctly, and a float tensor with
        the fraction of non-ignored positions predicted correctly.
    """
    preds = logits.argmax(dim=-1)
    mask = labels != IGNORE_LABEL_ID
    hits = preds == labels
    # Ignored positions count as correct for the exact-match test.
    exact = (hits | ~mask).all(dim=-1)
    # clamp_min(1) avoids 0/0 for sequences with no supervised position.
    denom = mask.sum(-1).clamp_min(1)
    token_acc = (hits & mask).sum(-1).float() / denom.float()
    return exact, token_acc


@torch.inference_mode()
def eval_sigma(
    model,
    batch: dict[str, torch.Tensor],
    sigma: float,
    rollouts: int,
    perturb: str,
    distribution: str,
    generator: torch.Generator,
):
    """Evaluate one batch under a single noise scale with several rollouts.

    Each problem is replicated ``rollouts`` times, the recurrent state is reset
    for every replica, noise is applied once to the reset state, and the inner
    model is then run for ``model.config.halt_max_steps`` steps. Metrics are
    computed from the logits of the final step.

    Returns:
        ``(exact, token_acc)``, each of shape ``(batch_size, rollouts)``.

    Raises:
        RuntimeError: If ``model.config.halt_max_steps`` is not positive, so no
            logits are produced.
    """
    expanded = repeat_batch(batch, rollouts)
    device = expanded["inputs"].device
    total = expanded["inputs"].shape[0]

    # Default-device context for tensors created inside initial_carry.
    with torch.device(device):
        carry = model.initial_carry(expanded)
    reset = torch.ones(total, device=device, dtype=torch.bool)
    inner = model.inner.reset_carry(reset, carry.inner_carry)
    inner = apply_initial_noise(inner, sigma, perturb, generator, distribution)

    logits = None
    for _ in range(model.config.halt_max_steps):
        inner, logits, _q = model.inner(inner, expanded)
    if logits is None:
        # Explicit check instead of `assert`, which is stripped under -O.
        raise RuntimeError("model.config.halt_max_steps must be >= 1 to produce logits")

    exact, token_acc = correctness(logits, expanded["labels"])
    base_bsz = batch["inputs"].shape[0]
    # repeat_interleave keeps a problem's rollouts adjacent, so this view maps
    # rows to problems and columns to rollouts.
    return exact.view(base_bsz, rollouts), token_acc.view(base_bsz, rollouts)


def summarize_sigma(exact: torch.Tensor, token_acc: torch.Tensor) -> dict[str, float]:
    """Aggregate ``(n_problems, rollouts)`` results into scalar statistics.

    ``correct_count_*`` describe the per-problem number of exact rollouts;
    ``pass_at_k`` / ``all_k`` are the fractions of problems with at least one /
    all rollouts exact.
    """
    exact_f = exact.float()
    correct_counts = exact_f.sum(dim=1)
    rollouts = exact.shape[1]
    return {
        "mean_rollout_exact": exact_f.mean().item(),
        "mean_rollout_token_acc": token_acc.mean().item(),
        "pass_at_k": exact.any(dim=1).float().mean().item(),
        "all_k": exact.all(dim=1).float().mean().item(),
        "correct_count_mean": correct_counts.mean().item(),
        "correct_count_std": correct_counts.std(unbiased=False).item(),
        "correct_count_q10": torch.quantile(correct_counts, 0.10).item(),
        "correct_count_q50": torch.quantile(correct_counts, 0.50).item(),
        "correct_count_q90": torch.quantile(correct_counts, 0.90).item(),
        "zero_frac": (correct_counts == 0).float().mean().item(),
        "full_frac": (correct_counts == rollouts).float().mean().item(),
    }


def write_summary(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write ``rows`` as CSV; columns are the union of keys in first-seen order.

    An empty ``rows`` list yields a file containing only an empty header line.
    """
    # dict.fromkeys gives an order-preserving union without a quadratic scan.
    keys = list(dict.fromkeys(key for row in rows for key in row))
    with path.open("w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=keys)
        writer.writeheader()
        writer.writerows(rows)


def _append_suffix(prefix: Path, suffix: str) -> Path:
    """Return ``prefix`` with ``suffix`` appended to its final path component.

    ``Path.with_suffix`` is not used because it replaces an existing suffix:
    a prefix such as ``runs/lr0.001`` would be truncated to ``runs/lr0``.
    """
    return prefix.with_name(prefix.name + suffix)


def main() -> None:
    """Run the sigma sweep from the command line and write the result files."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt-root", required=True)
    parser.add_argument("--ckpt-name", required=True)
    parser.add_argument("--label", required=True)
    parser.add_argument("--n-samples", type=int, default=2000)
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Number of original problems per batch; expanded batch is batch_size * rollouts.",
    )
    parser.add_argument("--rollouts", type=int, default=8)
    parser.add_argument("--sigmas", default="0,3e-5,1e-4,3e-4,1e-3,3e-3,1e-2,3e-2")
    parser.add_argument("--perturb", choices=["h", "l", "both"], default="both")
    parser.add_argument("--noise-distribution", choices=["gaussian", "uniform"], default="gaussian")
    parser.add_argument("--seed", type=int, default=20260605)
    parser.add_argument("--out-prefix", required=True)
    parser.add_argument("--device", default="cuda", help="Torch device for the model and the noise generator.")
    args = parser.parse_args()

    # Validate cheap arguments before the expensive model load.
    sigmas = parse_float_list(args.sigmas)
    if not sigmas:
        parser.error("--sigmas must contain at least one value")
    if args.n_samples < 1:
        parser.error("--n-samples must be >= 1")
    if args.batch_size < 1:
        parser.error("--batch-size must be >= 1")
    if args.rollouts < 1:
        parser.error("--rollouts must be >= 1")
    out_prefix = Path(args.out_prefix)
    if not out_prefix.name:
        parser.error("--out-prefix must end in a file-name prefix")
    summary_path = _append_suffix(out_prefix, ".summary.csv")
    meta_path = _append_suffix(out_prefix, ".meta.json")
    npz_path = _append_suffix(out_prefix, ".npz")

    device = args.device
    torch.manual_seed(args.seed)
    # Dedicated generator for the perturbation noise, seeded separately from
    # the global torch seed.
    generator = torch.Generator(device=device).manual_seed(args.seed + 101)

    model, cfg, data_path = load_model(Path(args.ckpt_root), args.ckpt_name, device)
    samples = load_test_samples(data_path, args.n_samples, args.seed)
    n = len(samples["inputs"])
    if n == 0:
        raise RuntimeError(f"no test samples found under {data_path / 'test'}")
    print(
        f"[run] label={args.label} n={n} rollouts={args.rollouts} "
        f"batch={args.batch_size} sigmas={sigmas}",
        flush=True,
    )

    rows: list[dict[str, Any]] = []
    all_exact = []
    all_token = []
    for sigma in sigmas:
        exact_parts = []
        token_parts = []
        for start in range(0, n, args.batch_size):
            end = min(start + args.batch_size, n)
            batch = batch_slice(samples, start, end, device)
            exact, token_acc = eval_sigma(
                model=model,
                batch=batch,
                sigma=sigma,
                rollouts=args.rollouts,
                perturb=args.perturb,
                distribution=args.noise_distribution,
                generator=generator,
            )
            exact_parts.append(exact.cpu())
            token_parts.append(token_acc.cpu())
            # Progress line every 10th batch and after the last one.
            if end == n or (end // args.batch_size) % 10 == 0:
                print(f"  sigma={sigma:g} [{end}/{n}]", flush=True)

        exact_all = torch.cat(exact_parts, dim=0)
        token_all = torch.cat(token_parts, dim=0)
        row: dict[str, Any] = {
            "label": args.label,
            "sigma": sigma,
            "n_samples": n,
            "rollouts": args.rollouts,
            "ckpt_root": str(Path(args.ckpt_root)),
            "ckpt_name": args.ckpt_name,
            "perturb": args.perturb,
            "noise_distribution": args.noise_distribution,
            **summarize_sigma(exact_all, token_all),
        }
        rows.append(row)
        all_exact.append(exact_all.numpy())
        all_token.append(token_all.numpy())
        print(
            f"  sigma={sigma:g} mean={row['mean_rollout_exact']:.4f} "
            f"pass@K={row['pass_at_k']:.4f} allK={row['all_k']:.4f}",
            flush=True,
        )

    out_prefix.parent.mkdir(parents=True, exist_ok=True)
    write_summary(summary_path, rows)
    meta = {
        "args": vars(args),
        "data_path": str(data_path),
        "config_global_batch_size": cfg.get("global_batch_size"),
        "sigmas": sigmas,
        "n_samples": n,
    }
    meta_path.write_text(json.dumps(meta, indent=2, sort_keys=True))
    # Arrays are stacked along a leading sigma axis: (n_sigmas, n, rollouts).
    np.savez_compressed(
        npz_path,
        idx=samples["idx"],
        sigmas=np.asarray(sigmas, dtype=np.float32),
        exact=np.stack(all_exact, axis=0),
        token_acc=np.stack(all_token, axis=0),
        meta_json=np.asarray(json.dumps(meta, sort_keys=True)),
    )
    print(f"[done] {summary_path}", flush=True)


if __name__ == "__main__":
    main()