"""Inference-time robustness to late recurrent-state perturbations. For each sample, run the model cleanly for `perturb_after` inner iterations, perturb z_H/z_L once, then continue the deterministic recurrent rollout. This stress-tests the attractor basin along the inference trajectory rather than only at the initial recurrent state. """ from __future__ import annotations import argparse import csv import json import math import sys from dataclasses import replace from pathlib import Path from typing import Any import numpy as np import torch import yaml TRM_DIR = Path("/home/yurenh2/rrm/trm") sys.path.insert(0, str(TRM_DIR)) from models.recursive_reasoning.trm import ( # noqa: E402 TinyRecursiveReasoningModel_ACTV1, TinyRecursiveReasoningModel_ACTV1InnerCarry, ) IGNORE_LABEL_ID = -100 def parse_float_list(text: str) -> list[float]: return [float(x.strip()) for x in text.split(",") if x.strip()] def parse_int_list(text: str) -> list[int]: return [int(x.strip()) for x in text.split(",") if x.strip()] def is_zero(value: float) -> bool: return abs(value) <= 1e-12 def load_model(ckpt_root: Path, ckpt_name: str, device: str): cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text()) data_path = Path(cfg.get("data_path") or cfg["data_paths"][0]) train_meta = json.loads((data_path / "train" / "dataset.json").read_text()) arch_cfg = dict(cfg["arch"]) arch_cfg.update( batch_size=cfg["global_batch_size"], seq_len=train_meta["seq_len"], vocab_size=train_meta["vocab_size"], num_puzzle_identifiers=train_meta["num_puzzle_identifiers"], causal=False, ) model = TinyRecursiveReasoningModel_ACTV1(arch_cfg) state = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True) stripped = {k.replace("_orig_mod.", "").replace("model.", ""): v for k, v in state.items()} missing, unexpected = model.load_state_dict(stripped, strict=False) print(f"[load] {ckpt_root.name}/{ckpt_name} missing={len(missing)} unexpected={len(unexpected)}", flush=True) model.to(device).eval() return model, cfg, data_path def load_test_samples(data_path: Path, n_samples: int, seed: int): rng = np.random.default_rng(seed) inputs = np.load(data_path / "test" / "all__inputs.npy") labels = np.load(data_path / "test" / "all__labels.npy") puzzle_ids = np.load(data_path / "test" / "all__puzzle_identifiers.npy") n = min(n_samples, len(inputs)) idx = rng.choice(len(inputs), size=n, replace=False) return { "inputs": torch.from_numpy(inputs[idx].astype(np.int32)), "labels": torch.from_numpy(labels[idx].astype(np.int32)), "puzzle_identifiers": torch.from_numpy(puzzle_ids[idx].astype(np.int32)), "idx": idx, } def batch_slice(samples: dict[str, Any], start: int, end: int, device: str): return { k: v[start:end].to(device, non_blocking=True) for k, v in samples.items() if k in ("inputs", "labels", "puzzle_identifiers") } def repeat_batch(batch: dict[str, torch.Tensor], repeats: int): if repeats == 1: return batch return {k: v.repeat_interleave(repeats, dim=0) for k, v in batch.items()} def sample_unit_noise_like(tensor: torch.Tensor, generator: torch.Generator, distribution: str): if distribution == "uniform": noise = 2.0 * torch.rand(tensor.shape, device=tensor.device, dtype=torch.float32, generator=generator) - 1.0 noise = noise * math.sqrt(3.0) else: noise = torch.randn(tensor.shape, device=tensor.device, dtype=torch.float32, generator=generator) return noise.to(tensor.dtype) def apply_noise( inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, sigma: float, perturb: str, generator: torch.Generator, distribution: str, ): if sigma <= 0: return inner z_h, z_l = inner.z_H, inner.z_L if perturb in ("h", "both"): z_h = z_h + sigma * sample_unit_noise_like(z_h, generator, distribution) if perturb in ("l", "both"): z_l = z_l + sigma * sample_unit_noise_like(z_l, generator, distribution) return replace(inner, z_H=z_h, z_L=z_l) def correctness(logits: torch.Tensor, labels: torch.Tensor): preds = logits.argmax(dim=-1) mask = labels != IGNORE_LABEL_ID exact = torch.where(mask, preds == labels, True).all(dim=-1) denom = mask.sum(-1).clamp_min(1) token_acc = ((preds == labels) & mask).sum(-1).float() / denom.float() return exact, token_acc @torch.inference_mode() def eval_condition( model, batch: dict[str, torch.Tensor], sigma: float, rollouts: int, perturb_after: int, perturb: str, distribution: str, generator: torch.Generator, ): expanded = repeat_batch(batch, rollouts) total = expanded["inputs"].shape[0] steps = model.config.halt_max_steps warmup = min(max(perturb_after, 0), steps - 1) with torch.device(expanded["inputs"].device): carry = model.initial_carry(expanded) reset = torch.ones(total, device=expanded["inputs"].device, dtype=torch.bool) inner = model.inner.reset_carry(reset, carry.inner_carry) logits = None for _ in range(warmup): inner, logits, _q = model.inner(inner, expanded) inner = apply_noise(inner, sigma, perturb, generator, distribution) for _ in range(warmup, steps): inner, logits, _q = model.inner(inner, expanded) assert logits is not None exact, token_acc = correctness(logits, expanded["labels"]) base_bsz = batch["inputs"].shape[0] return exact.view(base_bsz, rollouts), token_acc.view(base_bsz, rollouts) def summarize(exact: torch.Tensor, token_acc: torch.Tensor, clean_exact: torch.Tensor) -> dict[str, float]: correct_counts = exact.float().sum(dim=1) rollouts = exact.shape[1] clean_success = clean_exact.bool() clean_fail = ~clean_success out = { "mean_rollout_exact": exact.float().mean().item(), "mean_rollout_token_acc": token_acc.mean().item(), "pass_at_k": exact.any(dim=1).float().mean().item(), "all_k": exact.all(dim=1).float().mean().item(), "correct_count_mean": correct_counts.mean().item(), "correct_count_std": correct_counts.std(unbiased=False).item(), "zero_frac": (correct_counts == 0).float().mean().item(), "full_frac": (correct_counts == rollouts).float().mean().item(), "clean_acc": clean_success.float().mean().item(), } if clean_success.any().item(): out["retain_mean_on_clean_success"] = exact[clean_success].float().mean().item() out["allK_on_clean_success"] = exact[clean_success].all(dim=1).float().mean().item() else: out["retain_mean_on_clean_success"] = float("nan") out["allK_on_clean_success"] = float("nan") if clean_fail.any().item(): out["rescue_mean_on_clean_fail"] = exact[clean_fail].float().mean().item() out["passK_on_clean_fail"] = exact[clean_fail].any(dim=1).float().mean().item() else: out["rescue_mean_on_clean_fail"] = float("nan") out["passK_on_clean_fail"] = float("nan") return out def write_summary(path: Path, rows: list[dict[str, Any]]) -> None: keys = list(rows[0]) for row in rows[1:]: for key in row: if key not in keys: keys.append(key) with path.open("w", newline="") as f: writer = csv.DictWriter(f, fieldnames=keys) writer.writeheader() writer.writerows(rows) def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--ckpt-root", required=True) parser.add_argument("--ckpt-name", required=True) parser.add_argument("--label", required=True) parser.add_argument("--n-samples", type=int, default=3000) parser.add_argument("--batch-size", type=int, default=32) parser.add_argument("--rollouts", type=int, default=8) parser.add_argument("--sigmas", default="0,0.001,0.003,0.01,0.03,0.1") parser.add_argument("--perturb-afters", default="0,4,8,12,15") parser.add_argument("--perturb", choices=["h", "l", "both"], default="both") parser.add_argument("--noise-distribution", choices=["gaussian", "uniform"], default="gaussian") parser.add_argument("--seed", type=int, default=20260606) parser.add_argument("--out-prefix", required=True) args = parser.parse_args() device = "cuda" sigmas = parse_float_list(args.sigmas) perturb_afters = parse_int_list(args.perturb_afters) torch.manual_seed(args.seed) generator = torch.Generator(device=device).manual_seed(args.seed + 101) model, cfg, data_path = load_model(Path(args.ckpt_root), args.ckpt_name, device) samples = load_test_samples(data_path, args.n_samples, args.seed) n = len(samples["inputs"]) print( f"[run] label={args.label} n={n} rollouts={args.rollouts} batch={args.batch_size} " f"afters={perturb_afters} sigmas={sigmas}", flush=True, ) rows: list[dict[str, Any]] = [] all_exact = [] all_token = [] if not any(is_zero(s) for s in sigmas): sigmas = [0.0] + sigmas for after in perturb_afters: clean_exact = None pending: list[tuple[dict[str, Any], torch.Tensor, torch.Tensor]] = [] for sigma in sigmas: exact_parts = [] token_parts = [] for start in range(0, n, args.batch_size): end = min(start + args.batch_size, n) batch = batch_slice(samples, start, end, device) exact, token_acc = eval_condition( model, batch, sigma, args.rollouts, after, args.perturb, args.noise_distribution, generator ) exact_parts.append(exact.cpu()) token_parts.append(token_acc.cpu()) if end == n or (end // args.batch_size) % 10 == 0: print(f" after={after} sigma={sigma:g} [{end}/{n}]", flush=True) exact_all = torch.cat(exact_parts, dim=0) token_all = torch.cat(token_parts, dim=0) if is_zero(sigma): clean_exact = exact_all[:, 0].clone() print(f" after={after} clean grouping done clean_acc={clean_exact.float().mean().item():.4f}", flush=True) if clean_exact is None: pending.append(({"sigma": sigma}, exact_all, token_all)) continue row: dict[str, Any] = { "label": args.label, "perturb_after": after, "sigma": sigma, "n_samples": n, "rollouts": args.rollouts, "ckpt_root": str(Path(args.ckpt_root)), "ckpt_name": args.ckpt_name, "perturb": args.perturb, "noise_distribution": args.noise_distribution, **summarize(exact_all, token_all, clean_exact), } rows.append(row) all_exact.append(exact_all.numpy()) all_token.append(token_all.numpy()) print( f" after={after} sigma={sigma:g} mean={row['mean_rollout_exact']:.4f} " f"retain={row['retain_mean_on_clean_success']:.4f} " f"rescue={row['rescue_mean_on_clean_fail']:.4f}", flush=True, ) if pending: raise RuntimeError("sigma=0 must be evaluated before nonzero sigmas") out_prefix = Path(args.out_prefix) out_prefix.parent.mkdir(parents=True, exist_ok=True) write_summary(out_prefix.with_suffix(".summary.csv"), rows) meta = { "args": vars(args), "data_path": str(data_path), "config_global_batch_size": cfg.get("global_batch_size"), "sigmas": sigmas, "perturb_afters": perturb_afters, "n_samples": n, } out_prefix.with_suffix(".meta.json").write_text(json.dumps(meta, indent=2, sort_keys=True)) np.savez_compressed( out_prefix.with_suffix(".npz"), idx=samples["idx"], sigmas=np.asarray(sigmas, dtype=np.float32), perturb_afters=np.asarray(perturb_afters, dtype=np.int32), exact=np.stack(all_exact, axis=0), token_acc=np.stack(all_token, axis=0), meta_json=np.asarray(json.dumps(meta, sort_keys=True)), ) print(f"[done] {out_prefix}.summary.csv", flush=True) if __name__ == "__main__": main()