diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
| commit | 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch) | |
| tree | c29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/directional_lyap_perturb_robustness.py | |
Curated export for clone-and-run Maze training (2x A6000) + diagnostics.
trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible).
Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/directional_lyap_perturb_robustness.py')
| -rw-r--r-- | research/flossing/directional_lyap_perturb_robustness.py | 428 |
1 files changed, 428 insertions, 0 deletions
diff --git a/research/flossing/directional_lyap_perturb_robustness.py b/research/flossing/directional_lyap_perturb_robustness.py new file mode 100644 index 0000000..c0a803c --- /dev/null +++ b/research/flossing/directional_lyap_perturb_robustness.py @@ -0,0 +1,428 @@ +"""Directional late-state perturbation using finite-difference Lyapunov search. + +At a chosen recurrent step, sample several unit tangent directions, propagate +small shadow trajectories through the remaining deterministic dynamics, choose +the direction with maximal final hidden-state expansion, then perturb along +that selected direction with +/- sigma and measure answer robustness. + +This is a practical finite-difference proxy for perturbing along a local top +Lyapunov direction without paying exact JVP costs. +""" +from __future__ import annotations + +import argparse +import csv +import json +import math +import sys +from dataclasses import replace +from pathlib import Path +from typing import Any + +import numpy as np +import torch +import yaml + + +TRM_DIR = Path("/home/yurenh2/rrm/trm") +sys.path.insert(0, str(TRM_DIR)) + +from models.recursive_reasoning.trm import ( # noqa: E402 + TinyRecursiveReasoningModel_ACTV1, + TinyRecursiveReasoningModel_ACTV1InnerCarry, +) + + +IGNORE_LABEL_ID = -100 + + +def parse_float_list(text: str) -> list[float]: + return [float(x.strip()) for x in text.split(",") if x.strip()] + + +def parse_int_list(text: str) -> list[int]: + return [int(x.strip()) for x in text.split(",") if x.strip()] + + +def load_model(ckpt_root: Path, ckpt_name: str, device: str): + cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text()) + data_path = Path(cfg.get("data_path") or cfg["data_paths"][0]) + train_meta = json.loads((data_path / "train" / "dataset.json").read_text()) + + arch_cfg = dict(cfg["arch"]) + arch_cfg.update( + batch_size=cfg["global_batch_size"], + seq_len=train_meta["seq_len"], + vocab_size=train_meta["vocab_size"], + num_puzzle_identifiers=train_meta["num_puzzle_identifiers"], + causal=False, + ) + + model = TinyRecursiveReasoningModel_ACTV1(arch_cfg) + state = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True) + stripped = {k.replace("_orig_mod.", "").replace("model.", ""): v for k, v in state.items()} + missing, unexpected = model.load_state_dict(stripped, strict=False) + print(f"[load] {ckpt_root.name}/{ckpt_name} missing={len(missing)} unexpected={len(unexpected)}", flush=True) + model.to(device).eval() + return model, cfg, data_path + + +def load_test_samples(data_path: Path, n_samples: int, seed: int): + rng = np.random.default_rng(seed) + inputs = np.load(data_path / "test" / "all__inputs.npy") + labels = np.load(data_path / "test" / "all__labels.npy") + puzzle_ids = np.load(data_path / "test" / "all__puzzle_identifiers.npy") + + n = min(n_samples, len(inputs)) + idx = rng.choice(len(inputs), size=n, replace=False) + return { + "inputs": torch.from_numpy(inputs[idx].astype(np.int32)), + "labels": torch.from_numpy(labels[idx].astype(np.int32)), + "puzzle_identifiers": torch.from_numpy(puzzle_ids[idx].astype(np.int32)), + "idx": idx, + } + + +def batch_slice(samples: dict[str, Any], start: int, end: int, device: str): + return { + k: v[start:end].to(device, non_blocking=True) + for k, v in samples.items() + if k in ("inputs", "labels", "puzzle_identifiers") + } + + +def repeat_batch(batch: dict[str, torch.Tensor], repeats: int): + if repeats == 1: + return batch + return {k: v.repeat_interleave(repeats, dim=0) for k, v in batch.items()} + + +def cat_batches(a: dict[str, torch.Tensor], b: dict[str, torch.Tensor]): + return {k: torch.cat([a[k], b[k]], dim=0) for k in a} + + +def cat_inner(a: TinyRecursiveReasoningModel_ACTV1InnerCarry, b: TinyRecursiveReasoningModel_ACTV1InnerCarry): + return TinyRecursiveReasoningModel_ACTV1InnerCarry( + z_H=torch.cat([a.z_H, b.z_H], dim=0), + z_L=torch.cat([a.z_L, b.z_L], dim=0), + ) + + +def split_inner(inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, n_main: int): + return ( + TinyRecursiveReasoningModel_ACTV1InnerCarry( + z_H=inner.z_H[:n_main].contiguous(), + z_L=inner.z_L[:n_main].contiguous(), + ), + TinyRecursiveReasoningModel_ACTV1InnerCarry( + z_H=inner.z_H[n_main:].contiguous(), + z_L=inner.z_L[n_main:].contiguous(), + ), + ) + + +def correctness(logits: torch.Tensor, labels: torch.Tensor): + preds = logits.argmax(dim=-1) + mask = labels != IGNORE_LABEL_ID + exact = torch.where(mask, preds == labels, True).all(dim=-1) + denom = mask.sum(-1).clamp_min(1) + token_acc = ((preds == labels) & mask).sum(-1).float() / denom.float() + return exact, token_acc + + +def rand_unit_dirs( + inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, + candidates: int, + generator: torch.Generator, +): + bsz = inner.z_H.shape[0] + h_dirs = torch.randn( + (bsz, candidates) + tuple(inner.z_H.shape[1:]), + device=inner.z_H.device, + dtype=torch.float32, + generator=generator, + ) + l_dirs = torch.randn( + (bsz, candidates) + tuple(inner.z_L.shape[1:]), + device=inner.z_L.device, + dtype=torch.float32, + generator=generator, + ) + norm = torch.sqrt(h_dirs.flatten(2).square().sum(-1) + l_dirs.flatten(2).square().sum(-1)).clamp_min(1e-30) + h_view = (bsz, candidates) + (1,) * (h_dirs.ndim - 2) + l_view = (bsz, candidates) + (1,) * (l_dirs.ndim - 2) + return (h_dirs / norm.view(h_view)).to(inner.z_H.dtype), (l_dirs / norm.view(l_view)).to(inner.z_L.dtype) + + +def make_shadow_inner( + inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, + h_dirs: torch.Tensor, + l_dirs: torch.Tensor, + eps: float, +): + bsz, candidates = h_dirs.shape[:2] + z_h = inner.z_H[:, None] + eps * h_dirs + z_l = inner.z_L[:, None] + eps * l_dirs + return TinyRecursiveReasoningModel_ACTV1InnerCarry( + z_H=z_h.reshape((bsz * candidates,) + tuple(inner.z_H.shape[1:])).detach(), + z_L=z_l.reshape((bsz * candidates,) + tuple(inner.z_L.shape[1:])).detach(), + ) + + +def separation( + main: TinyRecursiveReasoningModel_ACTV1InnerCarry, + shadow: TinyRecursiveReasoningModel_ACTV1InnerCarry, + candidates: int, +): + bsz = main.z_H.shape[0] + sh = shadow.z_H.reshape((bsz, candidates) + tuple(main.z_H.shape[1:])).float() + sl = shadow.z_L.reshape((bsz, candidates) + tuple(main.z_L.shape[1:])).float() + dh = (sh - main.z_H[:, None].float()).flatten(2) + dl = (sl - main.z_L[:, None].float()).flatten(2) + return torch.sqrt(dh.square().sum(-1) + dl.square().sum(-1)).clamp_min(1e-30) + + +def gather_dirs(h_dirs: torch.Tensor, l_dirs: torch.Tensor, idx: torch.Tensor): + bsz = h_dirs.shape[0] + arange = torch.arange(bsz, device=h_dirs.device) + return h_dirs[arange, idx].contiguous(), l_dirs[arange, idx].contiguous() + + +def perturb_inner( + inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, + h_dir: torch.Tensor, + l_dir: torch.Tensor, + sigma: float, + sign: float, +): + return replace( + inner, + z_H=inner.z_H + (sign * sigma) * h_dir.to(inner.z_H.dtype), + z_L=inner.z_L + (sign * sigma) * l_dir.to(inner.z_L.dtype), + ) + + +@torch.inference_mode() +def warmup_inner(model, batch: dict[str, torch.Tensor], after: int): + bsz = batch["inputs"].shape[0] + with torch.device(batch["inputs"].device): + carry = model.initial_carry(batch) + reset = torch.ones(bsz, device=batch["inputs"].device, dtype=torch.bool) + inner = model.inner.reset_carry(reset, carry.inner_carry) + logits = None + for _ in range(after): + inner, logits, _q = model.inner(inner, batch) + return inner, logits + + +@torch.inference_mode() +def search_direction( + model, + warm: TinyRecursiveReasoningModel_ACTV1InnerCarry, + batch: dict[str, torch.Tensor], + after: int, + candidates: int, + fd_eps: float, + generator: torch.Generator, +): + bsz = batch["inputs"].shape[0] + remaining = model.config.halt_max_steps - after + h_dirs, l_dirs = rand_unit_dirs(warm, candidates, generator) + main = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=warm.z_H.detach(), z_L=warm.z_L.detach()) + shadow = make_shadow_inner(main, h_dirs, l_dirs, fd_eps) + combined = cat_inner(main, shadow) + combined_batch = cat_batches(batch, repeat_batch(batch, candidates)) + + logits = None + for _ in range(remaining): + combined, logits, _q = model.inner(combined, combined_batch) + assert logits is not None + main_final, shadow_final = split_inner(combined, bsz) + main_logits = logits[:bsz] + sep = separation(main_final, shadow_final, candidates) + best_idx = sep.argmax(dim=1) + best_sep = sep.gather(1, best_idx[:, None]).squeeze(1) + best_h, best_l = gather_dirs(h_dirs, l_dirs, best_idx) + growth = torch.log(best_sep / fd_eps).float() / max(remaining, 1) + clean_exact, clean_token = correctness(main_logits, batch["labels"]) + return best_h, best_l, growth, clean_exact, clean_token + + +@torch.inference_mode() +def eval_directional_sigma( + model, + warm: TinyRecursiveReasoningModel_ACTV1InnerCarry, + batch: dict[str, torch.Tensor], + after: int, + h_dir: torch.Tensor, + l_dir: torch.Tensor, + sigma: float, +): + remaining = model.config.halt_max_steps - after + plus = perturb_inner(warm, h_dir, l_dir, sigma, +1.0) + minus = perturb_inner(warm, h_dir, l_dir, sigma, -1.0) + inner = cat_inner(plus, minus) + combined_batch = cat_batches(batch, batch) + logits = None + for _ in range(remaining): + inner, logits, _q = model.inner(inner, combined_batch) + assert logits is not None + exact, token = correctness(logits, combined_batch["labels"]) + bsz = batch["inputs"].shape[0] + return exact.view(2, bsz).transpose(0, 1).contiguous(), token.view(2, bsz).transpose(0, 1).contiguous() + + +def summarize(exact: torch.Tensor, token: torch.Tensor, clean_exact: torch.Tensor, growth: torch.Tensor): + clean_success = clean_exact.bool() + clean_fail = ~clean_success + both = exact.all(dim=1) + either = exact.any(dim=1) + out = { + "clean_acc": clean_success.float().mean().item(), + "mean_sign_exact": exact.float().mean().item(), + "mean_sign_token_acc": token.mean().item(), + "worst_sign_exact": both.float().mean().item(), + "best_sign_exact": either.float().mean().item(), + "selected_growth_mean": growth.mean().item(), + "selected_growth_q90": torch.quantile(growth, 0.90).item(), + } + if clean_success.any().item(): + out["retain_mean_on_clean_success"] = exact[clean_success].float().mean().item() + out["retain_worst_on_clean_success"] = both[clean_success].float().mean().item() + else: + out["retain_mean_on_clean_success"] = float("nan") + out["retain_worst_on_clean_success"] = float("nan") + if clean_fail.any().item(): + out["rescue_mean_on_clean_fail"] = exact[clean_fail].float().mean().item() + out["rescue_best_on_clean_fail"] = either[clean_fail].float().mean().item() + else: + out["rescue_mean_on_clean_fail"] = float("nan") + out["rescue_best_on_clean_fail"] = float("nan") + return out + + +def write_summary(path: Path, rows: list[dict[str, Any]]) -> None: + keys = list(rows[0]) + for row in rows[1:]: + for key in row: + if key not in keys: + keys.append(key) + with path.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=keys) + writer.writeheader() + writer.writerows(rows) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt-root", required=True) + parser.add_argument("--ckpt-name", required=True) + parser.add_argument("--label", required=True) + parser.add_argument("--n-samples", type=int, default=1000) + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--candidates", type=int, default=8) + parser.add_argument("--fd-eps", type=float, default=1e-3) + parser.add_argument("--sigmas", default="0,0.001,0.003,0.01,0.03,0.1") + parser.add_argument("--perturb-afters", default="0,4,8,12") + parser.add_argument("--seed", type=int, default=20260607) + parser.add_argument("--out-prefix", required=True) + args = parser.parse_args() + + device = "cuda" + sigmas = parse_float_list(args.sigmas) + if not any(abs(s) <= 1e-12 for s in sigmas): + sigmas = [0.0] + sigmas + afters = parse_int_list(args.perturb_afters) + torch.manual_seed(args.seed) + generator = torch.Generator(device=device).manual_seed(args.seed + 17) + + model, cfg, data_path = load_model(Path(args.ckpt_root), args.ckpt_name, device) + samples = load_test_samples(data_path, args.n_samples, args.seed) + n = len(samples["inputs"]) + print( + f"[run] label={args.label} n={n} batch={args.batch_size} candidates={args.candidates} " + f"afters={afters} sigmas={sigmas}", + flush=True, + ) + + rows: list[dict[str, Any]] = [] + exact_store = [] + token_store = [] + growth_store = [] + for after in afters: + sigma_exact = {sigma: [] for sigma in sigmas} + sigma_token = {sigma: [] for sigma in sigmas} + growth_parts = [] + for start in range(0, n, args.batch_size): + end = min(start + args.batch_size, n) + batch = batch_slice(samples, start, end, device) + warm, _ = warmup_inner(model, batch, after) + h_dir, l_dir, growth, _search_clean_exact, _search_clean_token = search_direction( + model, warm, batch, after, args.candidates, args.fd_eps, generator + ) + growth_parts.append(growth.cpu()) + for sigma in sigmas: + exact, token = eval_directional_sigma(model, warm, batch, after, h_dir, l_dir, sigma) + sigma_exact[sigma].append(exact.cpu()) + sigma_token[sigma].append(token.cpu()) + if end == n or (end // args.batch_size) % 10 == 0: + print(f" after={after} [{end}/{n}]", flush=True) + + growth_all = torch.cat(growth_parts, dim=0) + growth_store.append(growth_all.numpy()) + exact_by_sigma = {sigma: torch.cat(sigma_exact[sigma], dim=0) for sigma in sigmas} + token_by_sigma = {sigma: torch.cat(sigma_token[sigma], dim=0) for sigma in sigmas} + clean_all = exact_by_sigma[0.0][:, 0].clone() + print(f" after={after} clean grouping done clean_acc={clean_all.float().mean().item():.4f}", flush=True) + for sigma in sigmas: + exact_all = exact_by_sigma[sigma] + token_all = token_by_sigma[sigma] + row: dict[str, Any] = { + "label": args.label, + "perturb_after": after, + "sigma": sigma, + "n_samples": n, + "candidates": args.candidates, + "fd_eps": args.fd_eps, + "ckpt_root": str(Path(args.ckpt_root)), + "ckpt_name": args.ckpt_name, + **summarize(exact_all, token_all, clean_all, growth_all), + } + rows.append(row) + exact_store.append(exact_all.numpy()) + token_store.append(token_all.numpy()) + print( + f" after={after} sigma={sigma:g} clean={row['clean_acc']:.4f} " + f"mean={row['mean_sign_exact']:.4f} worst={row['worst_sign_exact']:.4f} " + f"retain_worst={row['retain_worst_on_clean_success']:.4f} " + f"rescue_best={row['rescue_best_on_clean_fail']:.4f}", + flush=True, + ) + + out_prefix = Path(args.out_prefix) + out_prefix.parent.mkdir(parents=True, exist_ok=True) + write_summary(out_prefix.with_suffix(".summary.csv"), rows) + meta = { + "args": vars(args), + "data_path": str(data_path), + "config_global_batch_size": cfg.get("global_batch_size"), + "sigmas": sigmas, + "perturb_afters": afters, + "n_samples": n, + } + out_prefix.with_suffix(".meta.json").write_text(json.dumps(meta, indent=2, sort_keys=True)) + np.savez_compressed( + out_prefix.with_suffix(".npz"), + idx=samples["idx"], + sigmas=np.asarray(sigmas, dtype=np.float32), + perturb_afters=np.asarray(afters, dtype=np.int32), + exact=np.stack(exact_store, axis=0), + token_acc=np.stack(token_store, axis=0), + selected_growth=np.stack(growth_store, axis=0), + meta_json=np.asarray(json.dumps(meta, sort_keys=True)), + ) + print(f"[done] {out_prefix}.summary.csv", flush=True) + + +if __name__ == "__main__": + main() |
