summaryrefslogtreecommitdiff
path: root/research/flossing/initial_perturb_robustness.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
commit66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch)
treec29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/initial_perturb_robustness.py
rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipelineHEADmain
Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/initial_perturb_robustness.py')
-rw-r--r--research/flossing/initial_perturb_robustness.py286
1 files changed, 286 insertions, 0 deletions
diff --git a/research/flossing/initial_perturb_robustness.py b/research/flossing/initial_perturb_robustness.py
new file mode 100644
index 0000000..e652080
--- /dev/null
+++ b/research/flossing/initial_perturb_robustness.py
@@ -0,0 +1,286 @@
+"""Inference-time robustness to initial recurrent-state perturbations.
+
+This tests whether trajectory-perturbation training enlarged the correct
+attractor basin. Unlike PTRM-style rollout noise, the perturbation is applied
+once after resetting z_H/z_L, matching the training augmentation mechanism.
+"""
+from __future__ import annotations
+
+import argparse
+import csv
+import json
+import math
+import sys
+from dataclasses import replace
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+import yaml
+
+
+TRM_DIR = Path("/home/yurenh2/rrm/trm")
+sys.path.insert(0, str(TRM_DIR))
+
+from models.recursive_reasoning.trm import ( # noqa: E402
+ TinyRecursiveReasoningModel_ACTV1,
+ TinyRecursiveReasoningModel_ACTV1InnerCarry,
+)
+
+
+IGNORE_LABEL_ID = -100
+
+
+def parse_float_list(text: str) -> list[float]:
+ return [float(x.strip()) for x in text.split(",") if x.strip()]
+
+
+def load_model(ckpt_root: Path, ckpt_name: str, device: str):
+ cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text())
+ data_path = Path(cfg.get("data_path") or cfg["data_paths"][0])
+ train_meta = json.loads((data_path / "train" / "dataset.json").read_text())
+
+ arch_cfg = dict(cfg["arch"])
+ arch_cfg.update(
+ batch_size=cfg["global_batch_size"],
+ seq_len=train_meta["seq_len"],
+ vocab_size=train_meta["vocab_size"],
+ num_puzzle_identifiers=train_meta["num_puzzle_identifiers"],
+ causal=False,
+ )
+
+ model = TinyRecursiveReasoningModel_ACTV1(arch_cfg)
+ state = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True)
+ stripped = {k.replace("_orig_mod.", "").replace("model.", ""): v for k, v in state.items()}
+ missing, unexpected = model.load_state_dict(stripped, strict=False)
+ print(f"[load] {ckpt_root.name}/{ckpt_name} missing={len(missing)} unexpected={len(unexpected)}", flush=True)
+ if missing[:4]:
+ print(f"[load] sample missing: {missing[:4]}", flush=True)
+ if unexpected[:4]:
+ print(f"[load] sample unexpected: {unexpected[:4]}", flush=True)
+ model.to(device).eval()
+ return model, cfg, data_path
+
+
+def load_test_samples(data_path: Path, n_samples: int, seed: int):
+ rng = np.random.default_rng(seed)
+ inputs = np.load(data_path / "test" / "all__inputs.npy")
+ labels = np.load(data_path / "test" / "all__labels.npy")
+ puzzle_ids = np.load(data_path / "test" / "all__puzzle_identifiers.npy")
+
+ n = min(n_samples, len(inputs))
+ idx = rng.choice(len(inputs), size=n, replace=False)
+ return {
+ "inputs": torch.from_numpy(inputs[idx].astype(np.int32)),
+ "labels": torch.from_numpy(labels[idx].astype(np.int32)),
+ "puzzle_identifiers": torch.from_numpy(puzzle_ids[idx].astype(np.int32)),
+ "idx": idx,
+ }
+
+
+def batch_slice(samples: dict[str, Any], start: int, end: int, device: str):
+ return {
+ k: v[start:end].to(device, non_blocking=True)
+ for k, v in samples.items()
+ if k in ("inputs", "labels", "puzzle_identifiers")
+ }
+
+
+def repeat_batch(batch: dict[str, torch.Tensor], repeats: int):
+ if repeats == 1:
+ return batch
+ return {k: v.repeat_interleave(repeats, dim=0) for k, v in batch.items()}
+
+
+def sample_unit_noise_like(tensor: torch.Tensor, generator: torch.Generator, distribution: str):
+ if distribution == "uniform":
+ noise = 2.0 * torch.rand(tensor.shape, device=tensor.device, dtype=torch.float32, generator=generator) - 1.0
+ noise = noise * math.sqrt(3.0)
+ else:
+ noise = torch.randn(tensor.shape, device=tensor.device, dtype=torch.float32, generator=generator)
+ return noise.to(tensor.dtype)
+
+
+def apply_initial_noise(
+ inner: TinyRecursiveReasoningModel_ACTV1InnerCarry,
+ sigma: float,
+ perturb: str,
+ generator: torch.Generator,
+ distribution: str,
+):
+ if sigma <= 0:
+ return inner
+ z_h, z_l = inner.z_H, inner.z_L
+ if perturb in ("h", "both"):
+ z_h = z_h + sigma * sample_unit_noise_like(z_h, generator, distribution)
+ if perturb in ("l", "both"):
+ z_l = z_l + sigma * sample_unit_noise_like(z_l, generator, distribution)
+ return replace(inner, z_H=z_h, z_L=z_l)
+
+
+def correctness(logits: torch.Tensor, labels: torch.Tensor):
+ preds = logits.argmax(dim=-1)
+ mask = labels != IGNORE_LABEL_ID
+ exact = torch.where(mask, preds == labels, True).all(dim=-1)
+ denom = mask.sum(-1).clamp_min(1)
+ token_acc = ((preds == labels) & mask).sum(-1).float() / denom.float()
+ return exact, token_acc
+
+
+@torch.inference_mode()
+def eval_sigma(
+ model,
+ batch: dict[str, torch.Tensor],
+ sigma: float,
+ rollouts: int,
+ perturb: str,
+ distribution: str,
+ generator: torch.Generator,
+):
+ expanded = repeat_batch(batch, rollouts)
+ total = expanded["inputs"].shape[0]
+ with torch.device(expanded["inputs"].device):
+ carry = model.initial_carry(expanded)
+ reset = torch.ones(total, device=expanded["inputs"].device, dtype=torch.bool)
+ inner = model.inner.reset_carry(reset, carry.inner_carry)
+ inner = apply_initial_noise(inner, sigma, perturb, generator, distribution)
+
+ logits = None
+ for _ in range(model.config.halt_max_steps):
+ inner, logits, _q = model.inner(inner, expanded)
+
+ assert logits is not None
+ exact, token_acc = correctness(logits, expanded["labels"])
+ base_bsz = batch["inputs"].shape[0]
+ return exact.view(base_bsz, rollouts), token_acc.view(base_bsz, rollouts)
+
+
+def summarize_sigma(exact: torch.Tensor, token_acc: torch.Tensor) -> dict[str, float]:
+ correct_counts = exact.float().sum(dim=1)
+ rollouts = exact.shape[1]
+ return {
+ "mean_rollout_exact": exact.float().mean().item(),
+ "mean_rollout_token_acc": token_acc.mean().item(),
+ "pass_at_k": exact.any(dim=1).float().mean().item(),
+ "all_k": exact.all(dim=1).float().mean().item(),
+ "correct_count_mean": correct_counts.mean().item(),
+ "correct_count_std": correct_counts.std(unbiased=False).item(),
+ "correct_count_q10": torch.quantile(correct_counts, 0.10).item(),
+ "correct_count_q50": torch.quantile(correct_counts, 0.50).item(),
+ "correct_count_q90": torch.quantile(correct_counts, 0.90).item(),
+ "zero_frac": (correct_counts == 0).float().mean().item(),
+ "full_frac": (correct_counts == rollouts).float().mean().item(),
+ }
+
+
+def write_summary(path: Path, rows: list[dict[str, Any]]) -> None:
+ keys = list(rows[0])
+ for row in rows[1:]:
+ for key in row:
+ if key not in keys:
+ keys.append(key)
+ with path.open("w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=keys)
+ writer.writeheader()
+ writer.writerows(rows)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ckpt-root", required=True)
+ parser.add_argument("--ckpt-name", required=True)
+ parser.add_argument("--label", required=True)
+ parser.add_argument("--n-samples", type=int, default=2000)
+ parser.add_argument("--batch-size", type=int, default=32,
+ help="Number of original problems per batch; expanded batch is batch_size * rollouts.")
+ parser.add_argument("--rollouts", type=int, default=8)
+ parser.add_argument("--sigmas", default="0,3e-5,1e-4,3e-4,1e-3,3e-3,1e-2,3e-2")
+ parser.add_argument("--perturb", choices=["h", "l", "both"], default="both")
+ parser.add_argument("--noise-distribution", choices=["gaussian", "uniform"], default="gaussian")
+ parser.add_argument("--seed", type=int, default=20260605)
+ parser.add_argument("--out-prefix", required=True)
+ args = parser.parse_args()
+
+ device = "cuda"
+ sigmas = parse_float_list(args.sigmas)
+ torch.manual_seed(args.seed)
+ generator = torch.Generator(device=device).manual_seed(args.seed + 101)
+
+ model, cfg, data_path = load_model(Path(args.ckpt_root), args.ckpt_name, device)
+ samples = load_test_samples(data_path, args.n_samples, args.seed)
+ n = len(samples["inputs"])
+ print(
+ f"[run] label={args.label} n={n} rollouts={args.rollouts} "
+ f"batch={args.batch_size} sigmas={sigmas}",
+ flush=True,
+ )
+
+ rows: list[dict[str, Any]] = []
+ all_exact = []
+ all_token = []
+ for sigma in sigmas:
+ exact_parts = []
+ token_parts = []
+ for start in range(0, n, args.batch_size):
+ end = min(start + args.batch_size, n)
+ batch = batch_slice(samples, start, end, device)
+ exact, token_acc = eval_sigma(
+ model=model,
+ batch=batch,
+ sigma=sigma,
+ rollouts=args.rollouts,
+ perturb=args.perturb,
+ distribution=args.noise_distribution,
+ generator=generator,
+ )
+ exact_parts.append(exact.cpu())
+ token_parts.append(token_acc.cpu())
+ if end == n or (end // args.batch_size) % 10 == 0:
+ print(f" sigma={sigma:g} [{end}/{n}]", flush=True)
+ exact_all = torch.cat(exact_parts, dim=0)
+ token_all = torch.cat(token_parts, dim=0)
+ row: dict[str, Any] = {
+ "label": args.label,
+ "sigma": sigma,
+ "n_samples": n,
+ "rollouts": args.rollouts,
+ "ckpt_root": str(Path(args.ckpt_root)),
+ "ckpt_name": args.ckpt_name,
+ "perturb": args.perturb,
+ "noise_distribution": args.noise_distribution,
+ **summarize_sigma(exact_all, token_all),
+ }
+ rows.append(row)
+ all_exact.append(exact_all.numpy())
+ all_token.append(token_all.numpy())
+ print(
+ f" sigma={sigma:g} mean={row['mean_rollout_exact']:.4f} "
+ f"pass@K={row['pass_at_k']:.4f} allK={row['all_k']:.4f}",
+ flush=True,
+ )
+
+ out_prefix = Path(args.out_prefix)
+ out_prefix.parent.mkdir(parents=True, exist_ok=True)
+ write_summary(out_prefix.with_suffix(".summary.csv"), rows)
+ meta = {
+ "args": vars(args),
+ "data_path": str(data_path),
+ "config_global_batch_size": cfg.get("global_batch_size"),
+ "sigmas": sigmas,
+ "n_samples": n,
+ }
+ out_prefix.with_suffix(".meta.json").write_text(json.dumps(meta, indent=2, sort_keys=True))
+ np.savez_compressed(
+ out_prefix.with_suffix(".npz"),
+ idx=samples["idx"],
+ sigmas=np.asarray(sigmas, dtype=np.float32),
+ exact=np.stack(all_exact, axis=0),
+ token_acc=np.stack(all_token, axis=0),
+ meta_json=np.asarray(json.dumps(meta, sort_keys=True)),
+ )
+ print(f"[done] {out_prefix}.summary.csv", flush=True)
+
+
+if __name__ == "__main__":
+ main()