diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
| commit | 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch) | |
| tree | c29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/late_perturb_robustness.py | |
Curated export for clone-and-run Maze training (2x A6000) + diagnostics.
trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible).
Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/late_perturb_robustness.py')
| -rw-r--r-- | research/flossing/late_perturb_robustness.py | 323 |
1 files changed, 323 insertions, 0 deletions
diff --git a/research/flossing/late_perturb_robustness.py b/research/flossing/late_perturb_robustness.py new file mode 100644 index 0000000..7ad2ba5 --- /dev/null +++ b/research/flossing/late_perturb_robustness.py @@ -0,0 +1,323 @@ +"""Inference-time robustness to late recurrent-state perturbations. + +For each sample, run the model cleanly for `perturb_after` inner iterations, +perturb z_H/z_L once, then continue the deterministic recurrent rollout. This +stress-tests the attractor basin along the inference trajectory rather than +only at the initial recurrent state. +""" +from __future__ import annotations + +import argparse +import csv +import json +import math +import sys +from dataclasses import replace +from pathlib import Path +from typing import Any + +import numpy as np +import torch +import yaml + + +TRM_DIR = Path("/home/yurenh2/rrm/trm") +sys.path.insert(0, str(TRM_DIR)) + +from models.recursive_reasoning.trm import ( # noqa: E402 + TinyRecursiveReasoningModel_ACTV1, + TinyRecursiveReasoningModel_ACTV1InnerCarry, +) + + +IGNORE_LABEL_ID = -100 + + +def parse_float_list(text: str) -> list[float]: + return [float(x.strip()) for x in text.split(",") if x.strip()] + + +def parse_int_list(text: str) -> list[int]: + return [int(x.strip()) for x in text.split(",") if x.strip()] + + +def is_zero(value: float) -> bool: + return abs(value) <= 1e-12 + + +def load_model(ckpt_root: Path, ckpt_name: str, device: str): + cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text()) + data_path = Path(cfg.get("data_path") or cfg["data_paths"][0]) + train_meta = json.loads((data_path / "train" / "dataset.json").read_text()) + + arch_cfg = dict(cfg["arch"]) + arch_cfg.update( + batch_size=cfg["global_batch_size"], + seq_len=train_meta["seq_len"], + vocab_size=train_meta["vocab_size"], + num_puzzle_identifiers=train_meta["num_puzzle_identifiers"], + causal=False, + ) + + model = TinyRecursiveReasoningModel_ACTV1(arch_cfg) + state = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True) + stripped = {k.replace("_orig_mod.", "").replace("model.", ""): v for k, v in state.items()} + missing, unexpected = model.load_state_dict(stripped, strict=False) + print(f"[load] {ckpt_root.name}/{ckpt_name} missing={len(missing)} unexpected={len(unexpected)}", flush=True) + model.to(device).eval() + return model, cfg, data_path + + +def load_test_samples(data_path: Path, n_samples: int, seed: int): + rng = np.random.default_rng(seed) + inputs = np.load(data_path / "test" / "all__inputs.npy") + labels = np.load(data_path / "test" / "all__labels.npy") + puzzle_ids = np.load(data_path / "test" / "all__puzzle_identifiers.npy") + + n = min(n_samples, len(inputs)) + idx = rng.choice(len(inputs), size=n, replace=False) + return { + "inputs": torch.from_numpy(inputs[idx].astype(np.int32)), + "labels": torch.from_numpy(labels[idx].astype(np.int32)), + "puzzle_identifiers": torch.from_numpy(puzzle_ids[idx].astype(np.int32)), + "idx": idx, + } + + +def batch_slice(samples: dict[str, Any], start: int, end: int, device: str): + return { + k: v[start:end].to(device, non_blocking=True) + for k, v in samples.items() + if k in ("inputs", "labels", "puzzle_identifiers") + } + + +def repeat_batch(batch: dict[str, torch.Tensor], repeats: int): + if repeats == 1: + return batch + return {k: v.repeat_interleave(repeats, dim=0) for k, v in batch.items()} + + +def sample_unit_noise_like(tensor: torch.Tensor, generator: torch.Generator, distribution: str): + if distribution == "uniform": + noise = 2.0 * torch.rand(tensor.shape, device=tensor.device, dtype=torch.float32, generator=generator) - 1.0 + noise = noise * math.sqrt(3.0) + else: + noise = torch.randn(tensor.shape, device=tensor.device, dtype=torch.float32, generator=generator) + return noise.to(tensor.dtype) + + +def apply_noise( + inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, + sigma: float, + perturb: str, + generator: torch.Generator, + distribution: str, +): + if sigma <= 0: + return inner + z_h, z_l = inner.z_H, inner.z_L + if perturb in ("h", "both"): + z_h = z_h + sigma * sample_unit_noise_like(z_h, generator, distribution) + if perturb in ("l", "both"): + z_l = z_l + sigma * sample_unit_noise_like(z_l, generator, distribution) + return replace(inner, z_H=z_h, z_L=z_l) + + +def correctness(logits: torch.Tensor, labels: torch.Tensor): + preds = logits.argmax(dim=-1) + mask = labels != IGNORE_LABEL_ID + exact = torch.where(mask, preds == labels, True).all(dim=-1) + denom = mask.sum(-1).clamp_min(1) + token_acc = ((preds == labels) & mask).sum(-1).float() / denom.float() + return exact, token_acc + + +@torch.inference_mode() +def eval_condition( + model, + batch: dict[str, torch.Tensor], + sigma: float, + rollouts: int, + perturb_after: int, + perturb: str, + distribution: str, + generator: torch.Generator, +): + expanded = repeat_batch(batch, rollouts) + total = expanded["inputs"].shape[0] + steps = model.config.halt_max_steps + warmup = min(max(perturb_after, 0), steps - 1) + + with torch.device(expanded["inputs"].device): + carry = model.initial_carry(expanded) + reset = torch.ones(total, device=expanded["inputs"].device, dtype=torch.bool) + inner = model.inner.reset_carry(reset, carry.inner_carry) + + logits = None + for _ in range(warmup): + inner, logits, _q = model.inner(inner, expanded) + inner = apply_noise(inner, sigma, perturb, generator, distribution) + for _ in range(warmup, steps): + inner, logits, _q = model.inner(inner, expanded) + + assert logits is not None + exact, token_acc = correctness(logits, expanded["labels"]) + base_bsz = batch["inputs"].shape[0] + return exact.view(base_bsz, rollouts), token_acc.view(base_bsz, rollouts) + + +def summarize(exact: torch.Tensor, token_acc: torch.Tensor, clean_exact: torch.Tensor) -> dict[str, float]: + correct_counts = exact.float().sum(dim=1) + rollouts = exact.shape[1] + clean_success = clean_exact.bool() + clean_fail = ~clean_success + out = { + "mean_rollout_exact": exact.float().mean().item(), + "mean_rollout_token_acc": token_acc.mean().item(), + "pass_at_k": exact.any(dim=1).float().mean().item(), + "all_k": exact.all(dim=1).float().mean().item(), + "correct_count_mean": correct_counts.mean().item(), + "correct_count_std": correct_counts.std(unbiased=False).item(), + "zero_frac": (correct_counts == 0).float().mean().item(), + "full_frac": (correct_counts == rollouts).float().mean().item(), + "clean_acc": clean_success.float().mean().item(), + } + if clean_success.any().item(): + out["retain_mean_on_clean_success"] = exact[clean_success].float().mean().item() + out["allK_on_clean_success"] = exact[clean_success].all(dim=1).float().mean().item() + else: + out["retain_mean_on_clean_success"] = float("nan") + out["allK_on_clean_success"] = float("nan") + if clean_fail.any().item(): + out["rescue_mean_on_clean_fail"] = exact[clean_fail].float().mean().item() + out["passK_on_clean_fail"] = exact[clean_fail].any(dim=1).float().mean().item() + else: + out["rescue_mean_on_clean_fail"] = float("nan") + out["passK_on_clean_fail"] = float("nan") + return out + + +def write_summary(path: Path, rows: list[dict[str, Any]]) -> None: + keys = list(rows[0]) + for row in rows[1:]: + for key in row: + if key not in keys: + keys.append(key) + with path.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=keys) + writer.writeheader() + writer.writerows(rows) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt-root", required=True) + parser.add_argument("--ckpt-name", required=True) + parser.add_argument("--label", required=True) + parser.add_argument("--n-samples", type=int, default=3000) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--rollouts", type=int, default=8) + parser.add_argument("--sigmas", default="0,0.001,0.003,0.01,0.03,0.1") + parser.add_argument("--perturb-afters", default="0,4,8,12,15") + parser.add_argument("--perturb", choices=["h", "l", "both"], default="both") + parser.add_argument("--noise-distribution", choices=["gaussian", "uniform"], default="gaussian") + parser.add_argument("--seed", type=int, default=20260606) + parser.add_argument("--out-prefix", required=True) + args = parser.parse_args() + + device = "cuda" + sigmas = parse_float_list(args.sigmas) + perturb_afters = parse_int_list(args.perturb_afters) + torch.manual_seed(args.seed) + generator = torch.Generator(device=device).manual_seed(args.seed + 101) + + model, cfg, data_path = load_model(Path(args.ckpt_root), args.ckpt_name, device) + samples = load_test_samples(data_path, args.n_samples, args.seed) + n = len(samples["inputs"]) + print( + f"[run] label={args.label} n={n} rollouts={args.rollouts} batch={args.batch_size} " + f"afters={perturb_afters} sigmas={sigmas}", + flush=True, + ) + + rows: list[dict[str, Any]] = [] + all_exact = [] + all_token = [] + if not any(is_zero(s) for s in sigmas): + sigmas = [0.0] + sigmas + + for after in perturb_afters: + clean_exact = None + pending: list[tuple[dict[str, Any], torch.Tensor, torch.Tensor]] = [] + for sigma in sigmas: + exact_parts = [] + token_parts = [] + for start in range(0, n, args.batch_size): + end = min(start + args.batch_size, n) + batch = batch_slice(samples, start, end, device) + exact, token_acc = eval_condition( + model, batch, sigma, args.rollouts, after, args.perturb, args.noise_distribution, generator + ) + exact_parts.append(exact.cpu()) + token_parts.append(token_acc.cpu()) + if end == n or (end // args.batch_size) % 10 == 0: + print(f" after={after} sigma={sigma:g} [{end}/{n}]", flush=True) + exact_all = torch.cat(exact_parts, dim=0) + token_all = torch.cat(token_parts, dim=0) + if is_zero(sigma): + clean_exact = exact_all[:, 0].clone() + print(f" after={after} clean grouping done clean_acc={clean_exact.float().mean().item():.4f}", flush=True) + if clean_exact is None: + pending.append(({"sigma": sigma}, exact_all, token_all)) + continue + row: dict[str, Any] = { + "label": args.label, + "perturb_after": after, + "sigma": sigma, + "n_samples": n, + "rollouts": args.rollouts, + "ckpt_root": str(Path(args.ckpt_root)), + "ckpt_name": args.ckpt_name, + "perturb": args.perturb, + "noise_distribution": args.noise_distribution, + **summarize(exact_all, token_all, clean_exact), + } + rows.append(row) + all_exact.append(exact_all.numpy()) + all_token.append(token_all.numpy()) + print( + f" after={after} sigma={sigma:g} mean={row['mean_rollout_exact']:.4f} " + f"retain={row['retain_mean_on_clean_success']:.4f} " + f"rescue={row['rescue_mean_on_clean_fail']:.4f}", + flush=True, + ) + if pending: + raise RuntimeError("sigma=0 must be evaluated before nonzero sigmas") + + out_prefix = Path(args.out_prefix) + out_prefix.parent.mkdir(parents=True, exist_ok=True) + write_summary(out_prefix.with_suffix(".summary.csv"), rows) + meta = { + "args": vars(args), + "data_path": str(data_path), + "config_global_batch_size": cfg.get("global_batch_size"), + "sigmas": sigmas, + "perturb_afters": perturb_afters, + "n_samples": n, + } + out_prefix.with_suffix(".meta.json").write_text(json.dumps(meta, indent=2, sort_keys=True)) + np.savez_compressed( + out_prefix.with_suffix(".npz"), + idx=samples["idx"], + sigmas=np.asarray(sigmas, dtype=np.float32), + perturb_afters=np.asarray(perturb_afters, dtype=np.int32), + exact=np.stack(all_exact, axis=0), + token_acc=np.stack(all_token, axis=0), + meta_json=np.asarray(json.dumps(meta, sort_keys=True)), + ) + print(f"[done] {out_prefix}.summary.csv", flush=True) + + +if __name__ == "__main__": + main() |
