"""Directional late-state perturbation using finite-difference Lyapunov search. At a chosen recurrent step, sample several unit tangent directions, propagate small shadow trajectories through the remaining deterministic dynamics, choose the direction with maximal final hidden-state expansion, then perturb along that selected direction with +/- sigma and measure answer robustness. This is a practical finite-difference proxy for perturbing along a local top Lyapunov direction without paying exact JVP costs. """ from __future__ import annotations import argparse import csv import json import math import sys from dataclasses import replace from pathlib import Path from typing import Any import numpy as np import torch import yaml TRM_DIR = Path("/home/yurenh2/rrm/trm") sys.path.insert(0, str(TRM_DIR)) from models.recursive_reasoning.trm import ( # noqa: E402 TinyRecursiveReasoningModel_ACTV1, TinyRecursiveReasoningModel_ACTV1InnerCarry, ) IGNORE_LABEL_ID = -100 def parse_float_list(text: str) -> list[float]: return [float(x.strip()) for x in text.split(",") if x.strip()] def parse_int_list(text: str) -> list[int]: return [int(x.strip()) for x in text.split(",") if x.strip()] def load_model(ckpt_root: Path, ckpt_name: str, device: str): cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text()) data_path = Path(cfg.get("data_path") or cfg["data_paths"][0]) train_meta = json.loads((data_path / "train" / "dataset.json").read_text()) arch_cfg = dict(cfg["arch"]) arch_cfg.update( batch_size=cfg["global_batch_size"], seq_len=train_meta["seq_len"], vocab_size=train_meta["vocab_size"], num_puzzle_identifiers=train_meta["num_puzzle_identifiers"], causal=False, ) model = TinyRecursiveReasoningModel_ACTV1(arch_cfg) state = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True) stripped = {k.replace("_orig_mod.", "").replace("model.", ""): v for k, v in state.items()} missing, unexpected = model.load_state_dict(stripped, strict=False) print(f"[load] {ckpt_root.name}/{ckpt_name} missing={len(missing)} unexpected={len(unexpected)}", flush=True) model.to(device).eval() return model, cfg, data_path def load_test_samples(data_path: Path, n_samples: int, seed: int): rng = np.random.default_rng(seed) inputs = np.load(data_path / "test" / "all__inputs.npy") labels = np.load(data_path / "test" / "all__labels.npy") puzzle_ids = np.load(data_path / "test" / "all__puzzle_identifiers.npy") n = min(n_samples, len(inputs)) idx = rng.choice(len(inputs), size=n, replace=False) return { "inputs": torch.from_numpy(inputs[idx].astype(np.int32)), "labels": torch.from_numpy(labels[idx].astype(np.int32)), "puzzle_identifiers": torch.from_numpy(puzzle_ids[idx].astype(np.int32)), "idx": idx, } def batch_slice(samples: dict[str, Any], start: int, end: int, device: str): return { k: v[start:end].to(device, non_blocking=True) for k, v in samples.items() if k in ("inputs", "labels", "puzzle_identifiers") } def repeat_batch(batch: dict[str, torch.Tensor], repeats: int): if repeats == 1: return batch return {k: v.repeat_interleave(repeats, dim=0) for k, v in batch.items()} def cat_batches(a: dict[str, torch.Tensor], b: dict[str, torch.Tensor]): return {k: torch.cat([a[k], b[k]], dim=0) for k in a} def cat_inner(a: TinyRecursiveReasoningModel_ACTV1InnerCarry, b: TinyRecursiveReasoningModel_ACTV1InnerCarry): return TinyRecursiveReasoningModel_ACTV1InnerCarry( z_H=torch.cat([a.z_H, b.z_H], dim=0), z_L=torch.cat([a.z_L, b.z_L], dim=0), ) def split_inner(inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, n_main: int): return ( TinyRecursiveReasoningModel_ACTV1InnerCarry( z_H=inner.z_H[:n_main].contiguous(), z_L=inner.z_L[:n_main].contiguous(), ), TinyRecursiveReasoningModel_ACTV1InnerCarry( z_H=inner.z_H[n_main:].contiguous(), z_L=inner.z_L[n_main:].contiguous(), ), ) def correctness(logits: torch.Tensor, labels: torch.Tensor): preds = logits.argmax(dim=-1) mask = labels != IGNORE_LABEL_ID exact = torch.where(mask, preds == labels, True).all(dim=-1) denom = mask.sum(-1).clamp_min(1) token_acc = ((preds == labels) & mask).sum(-1).float() / denom.float() return exact, token_acc def rand_unit_dirs( inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, candidates: int, generator: torch.Generator, ): bsz = inner.z_H.shape[0] h_dirs = torch.randn( (bsz, candidates) + tuple(inner.z_H.shape[1:]), device=inner.z_H.device, dtype=torch.float32, generator=generator, ) l_dirs = torch.randn( (bsz, candidates) + tuple(inner.z_L.shape[1:]), device=inner.z_L.device, dtype=torch.float32, generator=generator, ) norm = torch.sqrt(h_dirs.flatten(2).square().sum(-1) + l_dirs.flatten(2).square().sum(-1)).clamp_min(1e-30) h_view = (bsz, candidates) + (1,) * (h_dirs.ndim - 2) l_view = (bsz, candidates) + (1,) * (l_dirs.ndim - 2) return (h_dirs / norm.view(h_view)).to(inner.z_H.dtype), (l_dirs / norm.view(l_view)).to(inner.z_L.dtype) def make_shadow_inner( inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, h_dirs: torch.Tensor, l_dirs: torch.Tensor, eps: float, ): bsz, candidates = h_dirs.shape[:2] z_h = inner.z_H[:, None] + eps * h_dirs z_l = inner.z_L[:, None] + eps * l_dirs return TinyRecursiveReasoningModel_ACTV1InnerCarry( z_H=z_h.reshape((bsz * candidates,) + tuple(inner.z_H.shape[1:])).detach(), z_L=z_l.reshape((bsz * candidates,) + tuple(inner.z_L.shape[1:])).detach(), ) def separation( main: TinyRecursiveReasoningModel_ACTV1InnerCarry, shadow: TinyRecursiveReasoningModel_ACTV1InnerCarry, candidates: int, ): bsz = main.z_H.shape[0] sh = shadow.z_H.reshape((bsz, candidates) + tuple(main.z_H.shape[1:])).float() sl = shadow.z_L.reshape((bsz, candidates) + tuple(main.z_L.shape[1:])).float() dh = (sh - main.z_H[:, None].float()).flatten(2) dl = (sl - main.z_L[:, None].float()).flatten(2) return torch.sqrt(dh.square().sum(-1) + dl.square().sum(-1)).clamp_min(1e-30) def gather_dirs(h_dirs: torch.Tensor, l_dirs: torch.Tensor, idx: torch.Tensor): bsz = h_dirs.shape[0] arange = torch.arange(bsz, device=h_dirs.device) return h_dirs[arange, idx].contiguous(), l_dirs[arange, idx].contiguous() def perturb_inner( inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, h_dir: torch.Tensor, l_dir: torch.Tensor, sigma: float, sign: float, ): return replace( inner, z_H=inner.z_H + (sign * sigma) * h_dir.to(inner.z_H.dtype), z_L=inner.z_L + (sign * sigma) * l_dir.to(inner.z_L.dtype), ) @torch.inference_mode() def warmup_inner(model, batch: dict[str, torch.Tensor], after: int): bsz = batch["inputs"].shape[0] with torch.device(batch["inputs"].device): carry = model.initial_carry(batch) reset = torch.ones(bsz, device=batch["inputs"].device, dtype=torch.bool) inner = model.inner.reset_carry(reset, carry.inner_carry) logits = None for _ in range(after): inner, logits, _q = model.inner(inner, batch) return inner, logits @torch.inference_mode() def search_direction( model, warm: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: dict[str, torch.Tensor], after: int, candidates: int, fd_eps: float, generator: torch.Generator, ): bsz = batch["inputs"].shape[0] remaining = model.config.halt_max_steps - after h_dirs, l_dirs = rand_unit_dirs(warm, candidates, generator) main = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=warm.z_H.detach(), z_L=warm.z_L.detach()) shadow = make_shadow_inner(main, h_dirs, l_dirs, fd_eps) combined = cat_inner(main, shadow) combined_batch = cat_batches(batch, repeat_batch(batch, candidates)) logits = None for _ in range(remaining): combined, logits, _q = model.inner(combined, combined_batch) assert logits is not None main_final, shadow_final = split_inner(combined, bsz) main_logits = logits[:bsz] sep = separation(main_final, shadow_final, candidates) best_idx = sep.argmax(dim=1) best_sep = sep.gather(1, best_idx[:, None]).squeeze(1) best_h, best_l = gather_dirs(h_dirs, l_dirs, best_idx) growth = torch.log(best_sep / fd_eps).float() / max(remaining, 1) clean_exact, clean_token = correctness(main_logits, batch["labels"]) return best_h, best_l, growth, clean_exact, clean_token @torch.inference_mode() def eval_directional_sigma( model, warm: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: dict[str, torch.Tensor], after: int, h_dir: torch.Tensor, l_dir: torch.Tensor, sigma: float, ): remaining = model.config.halt_max_steps - after plus = perturb_inner(warm, h_dir, l_dir, sigma, +1.0) minus = perturb_inner(warm, h_dir, l_dir, sigma, -1.0) inner = cat_inner(plus, minus) combined_batch = cat_batches(batch, batch) logits = None for _ in range(remaining): inner, logits, _q = model.inner(inner, combined_batch) assert logits is not None exact, token = correctness(logits, combined_batch["labels"]) bsz = batch["inputs"].shape[0] return exact.view(2, bsz).transpose(0, 1).contiguous(), token.view(2, bsz).transpose(0, 1).contiguous() def summarize(exact: torch.Tensor, token: torch.Tensor, clean_exact: torch.Tensor, growth: torch.Tensor): clean_success = clean_exact.bool() clean_fail = ~clean_success both = exact.all(dim=1) either = exact.any(dim=1) out = { "clean_acc": clean_success.float().mean().item(), "mean_sign_exact": exact.float().mean().item(), "mean_sign_token_acc": token.mean().item(), "worst_sign_exact": both.float().mean().item(), "best_sign_exact": either.float().mean().item(), "selected_growth_mean": growth.mean().item(), "selected_growth_q90": torch.quantile(growth, 0.90).item(), } if clean_success.any().item(): out["retain_mean_on_clean_success"] = exact[clean_success].float().mean().item() out["retain_worst_on_clean_success"] = both[clean_success].float().mean().item() else: out["retain_mean_on_clean_success"] = float("nan") out["retain_worst_on_clean_success"] = float("nan") if clean_fail.any().item(): out["rescue_mean_on_clean_fail"] = exact[clean_fail].float().mean().item() out["rescue_best_on_clean_fail"] = either[clean_fail].float().mean().item() else: out["rescue_mean_on_clean_fail"] = float("nan") out["rescue_best_on_clean_fail"] = float("nan") return out def write_summary(path: Path, rows: list[dict[str, Any]]) -> None: keys = list(rows[0]) for row in rows[1:]: for key in row: if key not in keys: keys.append(key) with path.open("w", newline="") as f: writer = csv.DictWriter(f, fieldnames=keys) writer.writeheader() writer.writerows(rows) def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--ckpt-root", required=True) parser.add_argument("--ckpt-name", required=True) parser.add_argument("--label", required=True) parser.add_argument("--n-samples", type=int, default=1000) parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--candidates", type=int, default=8) parser.add_argument("--fd-eps", type=float, default=1e-3) parser.add_argument("--sigmas", default="0,0.001,0.003,0.01,0.03,0.1") parser.add_argument("--perturb-afters", default="0,4,8,12") parser.add_argument("--seed", type=int, default=20260607) parser.add_argument("--out-prefix", required=True) args = parser.parse_args() device = "cuda" sigmas = parse_float_list(args.sigmas) if not any(abs(s) <= 1e-12 for s in sigmas): sigmas = [0.0] + sigmas afters = parse_int_list(args.perturb_afters) torch.manual_seed(args.seed) generator = torch.Generator(device=device).manual_seed(args.seed + 17) model, cfg, data_path = load_model(Path(args.ckpt_root), args.ckpt_name, device) samples = load_test_samples(data_path, args.n_samples, args.seed) n = len(samples["inputs"]) print( f"[run] label={args.label} n={n} batch={args.batch_size} candidates={args.candidates} " f"afters={afters} sigmas={sigmas}", flush=True, ) rows: list[dict[str, Any]] = [] exact_store = [] token_store = [] growth_store = [] for after in afters: sigma_exact = {sigma: [] for sigma in sigmas} sigma_token = {sigma: [] for sigma in sigmas} growth_parts = [] for start in range(0, n, args.batch_size): end = min(start + args.batch_size, n) batch = batch_slice(samples, start, end, device) warm, _ = warmup_inner(model, batch, after) h_dir, l_dir, growth, _search_clean_exact, _search_clean_token = search_direction( model, warm, batch, after, args.candidates, args.fd_eps, generator ) growth_parts.append(growth.cpu()) for sigma in sigmas: exact, token = eval_directional_sigma(model, warm, batch, after, h_dir, l_dir, sigma) sigma_exact[sigma].append(exact.cpu()) sigma_token[sigma].append(token.cpu()) if end == n or (end // args.batch_size) % 10 == 0: print(f" after={after} [{end}/{n}]", flush=True) growth_all = torch.cat(growth_parts, dim=0) growth_store.append(growth_all.numpy()) exact_by_sigma = {sigma: torch.cat(sigma_exact[sigma], dim=0) for sigma in sigmas} token_by_sigma = {sigma: torch.cat(sigma_token[sigma], dim=0) for sigma in sigmas} clean_all = exact_by_sigma[0.0][:, 0].clone() print(f" after={after} clean grouping done clean_acc={clean_all.float().mean().item():.4f}", flush=True) for sigma in sigmas: exact_all = exact_by_sigma[sigma] token_all = token_by_sigma[sigma] row: dict[str, Any] = { "label": args.label, "perturb_after": after, "sigma": sigma, "n_samples": n, "candidates": args.candidates, "fd_eps": args.fd_eps, "ckpt_root": str(Path(args.ckpt_root)), "ckpt_name": args.ckpt_name, **summarize(exact_all, token_all, clean_all, growth_all), } rows.append(row) exact_store.append(exact_all.numpy()) token_store.append(token_all.numpy()) print( f" after={after} sigma={sigma:g} clean={row['clean_acc']:.4f} " f"mean={row['mean_sign_exact']:.4f} worst={row['worst_sign_exact']:.4f} " f"retain_worst={row['retain_worst_on_clean_success']:.4f} " f"rescue_best={row['rescue_best_on_clean_fail']:.4f}", flush=True, ) out_prefix = Path(args.out_prefix) out_prefix.parent.mkdir(parents=True, exist_ok=True) write_summary(out_prefix.with_suffix(".summary.csv"), rows) meta = { "args": vars(args), "data_path": str(data_path), "config_global_batch_size": cfg.get("global_batch_size"), "sigmas": sigmas, "perturb_afters": afters, "n_samples": n, } out_prefix.with_suffix(".meta.json").write_text(json.dumps(meta, indent=2, sort_keys=True)) np.savez_compressed( out_prefix.with_suffix(".npz"), idx=samples["idx"], sigmas=np.asarray(sigmas, dtype=np.float32), perturb_afters=np.asarray(afters, dtype=np.int32), exact=np.stack(exact_store, axis=0), token_acc=np.stack(token_store, axis=0), selected_growth=np.stack(growth_store, axis=0), meta_json=np.asarray(json.dumps(meta, sort_keys=True)), ) print(f"[done] {out_prefix}.summary.csv", flush=True) if __name__ == "__main__": main()