From 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Sat, 13 Jun 2026 12:35:36 -0500 Subject: rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipeline Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 --- research/flossing/initial_perturb_robustness.py | 286 ++++++++++++++++++++++++ 1 file changed, 286 insertions(+) create mode 100644 research/flossing/initial_perturb_robustness.py (limited to 'research/flossing/initial_perturb_robustness.py') diff --git a/research/flossing/initial_perturb_robustness.py b/research/flossing/initial_perturb_robustness.py new file mode 100644 index 0000000..e652080 --- /dev/null +++ b/research/flossing/initial_perturb_robustness.py @@ -0,0 +1,286 @@ +"""Inference-time robustness to initial recurrent-state perturbations. + +This tests whether trajectory-perturbation training enlarged the correct +attractor basin. Unlike PTRM-style rollout noise, the perturbation is applied +once after resetting z_H/z_L, matching the training augmentation mechanism. +""" +from __future__ import annotations + +import argparse +import csv +import json +import math +import sys +from dataclasses import replace +from pathlib import Path +from typing import Any + +import numpy as np +import torch +import yaml + + +TRM_DIR = Path("/home/yurenh2/rrm/trm") +sys.path.insert(0, str(TRM_DIR)) + +from models.recursive_reasoning.trm import ( # noqa: E402 + TinyRecursiveReasoningModel_ACTV1, + TinyRecursiveReasoningModel_ACTV1InnerCarry, +) + + +IGNORE_LABEL_ID = -100 + + +def parse_float_list(text: str) -> list[float]: + return [float(x.strip()) for x in text.split(",") if x.strip()] + + +def load_model(ckpt_root: Path, ckpt_name: str, device: str): + cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text()) + data_path = Path(cfg.get("data_path") or cfg["data_paths"][0]) + train_meta = json.loads((data_path / "train" / "dataset.json").read_text()) + + arch_cfg = dict(cfg["arch"]) + arch_cfg.update( + batch_size=cfg["global_batch_size"], + seq_len=train_meta["seq_len"], + vocab_size=train_meta["vocab_size"], + num_puzzle_identifiers=train_meta["num_puzzle_identifiers"], + causal=False, + ) + + model = TinyRecursiveReasoningModel_ACTV1(arch_cfg) + state = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True) + stripped = {k.replace("_orig_mod.", "").replace("model.", ""): v for k, v in state.items()} + missing, unexpected = model.load_state_dict(stripped, strict=False) + print(f"[load] {ckpt_root.name}/{ckpt_name} missing={len(missing)} unexpected={len(unexpected)}", flush=True) + if missing[:4]: + print(f"[load] sample missing: {missing[:4]}", flush=True) + if unexpected[:4]: + print(f"[load] sample unexpected: {unexpected[:4]}", flush=True) + model.to(device).eval() + return model, cfg, data_path + + +def load_test_samples(data_path: Path, n_samples: int, seed: int): + rng = np.random.default_rng(seed) + inputs = np.load(data_path / "test" / "all__inputs.npy") + labels = np.load(data_path / "test" / "all__labels.npy") + puzzle_ids = np.load(data_path / "test" / "all__puzzle_identifiers.npy") + + n = min(n_samples, len(inputs)) + idx = rng.choice(len(inputs), size=n, replace=False) + return { + "inputs": torch.from_numpy(inputs[idx].astype(np.int32)), + "labels": torch.from_numpy(labels[idx].astype(np.int32)), + "puzzle_identifiers": torch.from_numpy(puzzle_ids[idx].astype(np.int32)), + "idx": idx, + } + + +def batch_slice(samples: dict[str, Any], start: int, end: int, device: str): + return { + k: v[start:end].to(device, non_blocking=True) + for k, v in samples.items() + if k in ("inputs", "labels", "puzzle_identifiers") + } + + +def repeat_batch(batch: dict[str, torch.Tensor], repeats: int): + if repeats == 1: + return batch + return {k: v.repeat_interleave(repeats, dim=0) for k, v in batch.items()} + + +def sample_unit_noise_like(tensor: torch.Tensor, generator: torch.Generator, distribution: str): + if distribution == "uniform": + noise = 2.0 * torch.rand(tensor.shape, device=tensor.device, dtype=torch.float32, generator=generator) - 1.0 + noise = noise * math.sqrt(3.0) + else: + noise = torch.randn(tensor.shape, device=tensor.device, dtype=torch.float32, generator=generator) + return noise.to(tensor.dtype) + + +def apply_initial_noise( + inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, + sigma: float, + perturb: str, + generator: torch.Generator, + distribution: str, +): + if sigma <= 0: + return inner + z_h, z_l = inner.z_H, inner.z_L + if perturb in ("h", "both"): + z_h = z_h + sigma * sample_unit_noise_like(z_h, generator, distribution) + if perturb in ("l", "both"): + z_l = z_l + sigma * sample_unit_noise_like(z_l, generator, distribution) + return replace(inner, z_H=z_h, z_L=z_l) + + +def correctness(logits: torch.Tensor, labels: torch.Tensor): + preds = logits.argmax(dim=-1) + mask = labels != IGNORE_LABEL_ID + exact = torch.where(mask, preds == labels, True).all(dim=-1) + denom = mask.sum(-1).clamp_min(1) + token_acc = ((preds == labels) & mask).sum(-1).float() / denom.float() + return exact, token_acc + + +@torch.inference_mode() +def eval_sigma( + model, + batch: dict[str, torch.Tensor], + sigma: float, + rollouts: int, + perturb: str, + distribution: str, + generator: torch.Generator, +): + expanded = repeat_batch(batch, rollouts) + total = expanded["inputs"].shape[0] + with torch.device(expanded["inputs"].device): + carry = model.initial_carry(expanded) + reset = torch.ones(total, device=expanded["inputs"].device, dtype=torch.bool) + inner = model.inner.reset_carry(reset, carry.inner_carry) + inner = apply_initial_noise(inner, sigma, perturb, generator, distribution) + + logits = None + for _ in range(model.config.halt_max_steps): + inner, logits, _q = model.inner(inner, expanded) + + assert logits is not None + exact, token_acc = correctness(logits, expanded["labels"]) + base_bsz = batch["inputs"].shape[0] + return exact.view(base_bsz, rollouts), token_acc.view(base_bsz, rollouts) + + +def summarize_sigma(exact: torch.Tensor, token_acc: torch.Tensor) -> dict[str, float]: + correct_counts = exact.float().sum(dim=1) + rollouts = exact.shape[1] + return { + "mean_rollout_exact": exact.float().mean().item(), + "mean_rollout_token_acc": token_acc.mean().item(), + "pass_at_k": exact.any(dim=1).float().mean().item(), + "all_k": exact.all(dim=1).float().mean().item(), + "correct_count_mean": correct_counts.mean().item(), + "correct_count_std": correct_counts.std(unbiased=False).item(), + "correct_count_q10": torch.quantile(correct_counts, 0.10).item(), + "correct_count_q50": torch.quantile(correct_counts, 0.50).item(), + "correct_count_q90": torch.quantile(correct_counts, 0.90).item(), + "zero_frac": (correct_counts == 0).float().mean().item(), + "full_frac": (correct_counts == rollouts).float().mean().item(), + } + + +def write_summary(path: Path, rows: list[dict[str, Any]]) -> None: + keys = list(rows[0]) + for row in rows[1:]: + for key in row: + if key not in keys: + keys.append(key) + with path.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=keys) + writer.writeheader() + writer.writerows(rows) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt-root", required=True) + parser.add_argument("--ckpt-name", required=True) + parser.add_argument("--label", required=True) + parser.add_argument("--n-samples", type=int, default=2000) + parser.add_argument("--batch-size", type=int, default=32, + help="Number of original problems per batch; expanded batch is batch_size * rollouts.") + parser.add_argument("--rollouts", type=int, default=8) + parser.add_argument("--sigmas", default="0,3e-5,1e-4,3e-4,1e-3,3e-3,1e-2,3e-2") + parser.add_argument("--perturb", choices=["h", "l", "both"], default="both") + parser.add_argument("--noise-distribution", choices=["gaussian", "uniform"], default="gaussian") + parser.add_argument("--seed", type=int, default=20260605) + parser.add_argument("--out-prefix", required=True) + args = parser.parse_args() + + device = "cuda" + sigmas = parse_float_list(args.sigmas) + torch.manual_seed(args.seed) + generator = torch.Generator(device=device).manual_seed(args.seed + 101) + + model, cfg, data_path = load_model(Path(args.ckpt_root), args.ckpt_name, device) + samples = load_test_samples(data_path, args.n_samples, args.seed) + n = len(samples["inputs"]) + print( + f"[run] label={args.label} n={n} rollouts={args.rollouts} " + f"batch={args.batch_size} sigmas={sigmas}", + flush=True, + ) + + rows: list[dict[str, Any]] = [] + all_exact = [] + all_token = [] + for sigma in sigmas: + exact_parts = [] + token_parts = [] + for start in range(0, n, args.batch_size): + end = min(start + args.batch_size, n) + batch = batch_slice(samples, start, end, device) + exact, token_acc = eval_sigma( + model=model, + batch=batch, + sigma=sigma, + rollouts=args.rollouts, + perturb=args.perturb, + distribution=args.noise_distribution, + generator=generator, + ) + exact_parts.append(exact.cpu()) + token_parts.append(token_acc.cpu()) + if end == n or (end // args.batch_size) % 10 == 0: + print(f" sigma={sigma:g} [{end}/{n}]", flush=True) + exact_all = torch.cat(exact_parts, dim=0) + token_all = torch.cat(token_parts, dim=0) + row: dict[str, Any] = { + "label": args.label, + "sigma": sigma, + "n_samples": n, + "rollouts": args.rollouts, + "ckpt_root": str(Path(args.ckpt_root)), + "ckpt_name": args.ckpt_name, + "perturb": args.perturb, + "noise_distribution": args.noise_distribution, + **summarize_sigma(exact_all, token_all), + } + rows.append(row) + all_exact.append(exact_all.numpy()) + all_token.append(token_all.numpy()) + print( + f" sigma={sigma:g} mean={row['mean_rollout_exact']:.4f} " + f"pass@K={row['pass_at_k']:.4f} allK={row['all_k']:.4f}", + flush=True, + ) + + out_prefix = Path(args.out_prefix) + out_prefix.parent.mkdir(parents=True, exist_ok=True) + write_summary(out_prefix.with_suffix(".summary.csv"), rows) + meta = { + "args": vars(args), + "data_path": str(data_path), + "config_global_batch_size": cfg.get("global_batch_size"), + "sigmas": sigmas, + "n_samples": n, + } + out_prefix.with_suffix(".meta.json").write_text(json.dumps(meta, indent=2, sort_keys=True)) + np.savez_compressed( + out_prefix.with_suffix(".npz"), + idx=samples["idx"], + sigmas=np.asarray(sigmas, dtype=np.float32), + exact=np.stack(all_exact, axis=0), + token_acc=np.stack(all_token, axis=0), + meta_json=np.asarray(json.dumps(meta, sort_keys=True)), + ) + print(f"[done] {out_prefix}.summary.csv", flush=True) + + +if __name__ == "__main__": + main() -- cgit v1.2.3