summaryrefslogtreecommitdiff
path: root/research/flossing/directional_lyap_perturb_robustness.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
commit66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch)
treec29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/directional_lyap_perturb_robustness.py
rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipelineHEADmain
Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/directional_lyap_perturb_robustness.py')
-rw-r--r--research/flossing/directional_lyap_perturb_robustness.py428
1 files changed, 428 insertions, 0 deletions
diff --git a/research/flossing/directional_lyap_perturb_robustness.py b/research/flossing/directional_lyap_perturb_robustness.py
new file mode 100644
index 0000000..c0a803c
--- /dev/null
+++ b/research/flossing/directional_lyap_perturb_robustness.py
@@ -0,0 +1,428 @@
+"""Directional late-state perturbation using finite-difference Lyapunov search.
+
+At a chosen recurrent step, sample several unit tangent directions, propagate
+small shadow trajectories through the remaining deterministic dynamics, choose
+the direction with maximal final hidden-state expansion, then perturb along
+that selected direction with +/- sigma and measure answer robustness.
+
+This is a practical finite-difference proxy for perturbing along a local top
+Lyapunov direction without paying exact JVP costs.
+"""
+from __future__ import annotations
+
+import argparse
+import csv
+import json
+import math
+import sys
+from dataclasses import replace
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+import yaml
+
+
+TRM_DIR = Path("/home/yurenh2/rrm/trm")
+sys.path.insert(0, str(TRM_DIR))
+
+from models.recursive_reasoning.trm import ( # noqa: E402
+ TinyRecursiveReasoningModel_ACTV1,
+ TinyRecursiveReasoningModel_ACTV1InnerCarry,
+)
+
+
+IGNORE_LABEL_ID = -100
+
+
+def parse_float_list(text: str) -> list[float]:
+ return [float(x.strip()) for x in text.split(",") if x.strip()]
+
+
+def parse_int_list(text: str) -> list[int]:
+ return [int(x.strip()) for x in text.split(",") if x.strip()]
+
+
+def load_model(ckpt_root: Path, ckpt_name: str, device: str):
+ cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text())
+ data_path = Path(cfg.get("data_path") or cfg["data_paths"][0])
+ train_meta = json.loads((data_path / "train" / "dataset.json").read_text())
+
+ arch_cfg = dict(cfg["arch"])
+ arch_cfg.update(
+ batch_size=cfg["global_batch_size"],
+ seq_len=train_meta["seq_len"],
+ vocab_size=train_meta["vocab_size"],
+ num_puzzle_identifiers=train_meta["num_puzzle_identifiers"],
+ causal=False,
+ )
+
+ model = TinyRecursiveReasoningModel_ACTV1(arch_cfg)
+ state = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True)
+ stripped = {k.replace("_orig_mod.", "").replace("model.", ""): v for k, v in state.items()}
+ missing, unexpected = model.load_state_dict(stripped, strict=False)
+ print(f"[load] {ckpt_root.name}/{ckpt_name} missing={len(missing)} unexpected={len(unexpected)}", flush=True)
+ model.to(device).eval()
+ return model, cfg, data_path
+
+
+def load_test_samples(data_path: Path, n_samples: int, seed: int):
+ rng = np.random.default_rng(seed)
+ inputs = np.load(data_path / "test" / "all__inputs.npy")
+ labels = np.load(data_path / "test" / "all__labels.npy")
+ puzzle_ids = np.load(data_path / "test" / "all__puzzle_identifiers.npy")
+
+ n = min(n_samples, len(inputs))
+ idx = rng.choice(len(inputs), size=n, replace=False)
+ return {
+ "inputs": torch.from_numpy(inputs[idx].astype(np.int32)),
+ "labels": torch.from_numpy(labels[idx].astype(np.int32)),
+ "puzzle_identifiers": torch.from_numpy(puzzle_ids[idx].astype(np.int32)),
+ "idx": idx,
+ }
+
+
+def batch_slice(samples: dict[str, Any], start: int, end: int, device: str):
+ return {
+ k: v[start:end].to(device, non_blocking=True)
+ for k, v in samples.items()
+ if k in ("inputs", "labels", "puzzle_identifiers")
+ }
+
+
+def repeat_batch(batch: dict[str, torch.Tensor], repeats: int):
+ if repeats == 1:
+ return batch
+ return {k: v.repeat_interleave(repeats, dim=0) for k, v in batch.items()}
+
+
+def cat_batches(a: dict[str, torch.Tensor], b: dict[str, torch.Tensor]):
+ return {k: torch.cat([a[k], b[k]], dim=0) for k in a}
+
+
+def cat_inner(a: TinyRecursiveReasoningModel_ACTV1InnerCarry, b: TinyRecursiveReasoningModel_ACTV1InnerCarry):
+ return TinyRecursiveReasoningModel_ACTV1InnerCarry(
+ z_H=torch.cat([a.z_H, b.z_H], dim=0),
+ z_L=torch.cat([a.z_L, b.z_L], dim=0),
+ )
+
+
+def split_inner(inner: TinyRecursiveReasoningModel_ACTV1InnerCarry, n_main: int):
+ return (
+ TinyRecursiveReasoningModel_ACTV1InnerCarry(
+ z_H=inner.z_H[:n_main].contiguous(),
+ z_L=inner.z_L[:n_main].contiguous(),
+ ),
+ TinyRecursiveReasoningModel_ACTV1InnerCarry(
+ z_H=inner.z_H[n_main:].contiguous(),
+ z_L=inner.z_L[n_main:].contiguous(),
+ ),
+ )
+
+
+def correctness(logits: torch.Tensor, labels: torch.Tensor):
+ preds = logits.argmax(dim=-1)
+ mask = labels != IGNORE_LABEL_ID
+ exact = torch.where(mask, preds == labels, True).all(dim=-1)
+ denom = mask.sum(-1).clamp_min(1)
+ token_acc = ((preds == labels) & mask).sum(-1).float() / denom.float()
+ return exact, token_acc
+
+
+def rand_unit_dirs(
+ inner: TinyRecursiveReasoningModel_ACTV1InnerCarry,
+ candidates: int,
+ generator: torch.Generator,
+):
+ bsz = inner.z_H.shape[0]
+ h_dirs = torch.randn(
+ (bsz, candidates) + tuple(inner.z_H.shape[1:]),
+ device=inner.z_H.device,
+ dtype=torch.float32,
+ generator=generator,
+ )
+ l_dirs = torch.randn(
+ (bsz, candidates) + tuple(inner.z_L.shape[1:]),
+ device=inner.z_L.device,
+ dtype=torch.float32,
+ generator=generator,
+ )
+ norm = torch.sqrt(h_dirs.flatten(2).square().sum(-1) + l_dirs.flatten(2).square().sum(-1)).clamp_min(1e-30)
+ h_view = (bsz, candidates) + (1,) * (h_dirs.ndim - 2)
+ l_view = (bsz, candidates) + (1,) * (l_dirs.ndim - 2)
+ return (h_dirs / norm.view(h_view)).to(inner.z_H.dtype), (l_dirs / norm.view(l_view)).to(inner.z_L.dtype)
+
+
+def make_shadow_inner(
+ inner: TinyRecursiveReasoningModel_ACTV1InnerCarry,
+ h_dirs: torch.Tensor,
+ l_dirs: torch.Tensor,
+ eps: float,
+):
+ bsz, candidates = h_dirs.shape[:2]
+ z_h = inner.z_H[:, None] + eps * h_dirs
+ z_l = inner.z_L[:, None] + eps * l_dirs
+ return TinyRecursiveReasoningModel_ACTV1InnerCarry(
+ z_H=z_h.reshape((bsz * candidates,) + tuple(inner.z_H.shape[1:])).detach(),
+ z_L=z_l.reshape((bsz * candidates,) + tuple(inner.z_L.shape[1:])).detach(),
+ )
+
+
+def separation(
+ main: TinyRecursiveReasoningModel_ACTV1InnerCarry,
+ shadow: TinyRecursiveReasoningModel_ACTV1InnerCarry,
+ candidates: int,
+):
+ bsz = main.z_H.shape[0]
+ sh = shadow.z_H.reshape((bsz, candidates) + tuple(main.z_H.shape[1:])).float()
+ sl = shadow.z_L.reshape((bsz, candidates) + tuple(main.z_L.shape[1:])).float()
+ dh = (sh - main.z_H[:, None].float()).flatten(2)
+ dl = (sl - main.z_L[:, None].float()).flatten(2)
+ return torch.sqrt(dh.square().sum(-1) + dl.square().sum(-1)).clamp_min(1e-30)
+
+
+def gather_dirs(h_dirs: torch.Tensor, l_dirs: torch.Tensor, idx: torch.Tensor):
+ bsz = h_dirs.shape[0]
+ arange = torch.arange(bsz, device=h_dirs.device)
+ return h_dirs[arange, idx].contiguous(), l_dirs[arange, idx].contiguous()
+
+
+def perturb_inner(
+ inner: TinyRecursiveReasoningModel_ACTV1InnerCarry,
+ h_dir: torch.Tensor,
+ l_dir: torch.Tensor,
+ sigma: float,
+ sign: float,
+):
+ return replace(
+ inner,
+ z_H=inner.z_H + (sign * sigma) * h_dir.to(inner.z_H.dtype),
+ z_L=inner.z_L + (sign * sigma) * l_dir.to(inner.z_L.dtype),
+ )
+
+
+@torch.inference_mode()
+def warmup_inner(model, batch: dict[str, torch.Tensor], after: int):
+ bsz = batch["inputs"].shape[0]
+ with torch.device(batch["inputs"].device):
+ carry = model.initial_carry(batch)
+ reset = torch.ones(bsz, device=batch["inputs"].device, dtype=torch.bool)
+ inner = model.inner.reset_carry(reset, carry.inner_carry)
+ logits = None
+ for _ in range(after):
+ inner, logits, _q = model.inner(inner, batch)
+ return inner, logits
+
+
+@torch.inference_mode()
+def search_direction(
+ model,
+ warm: TinyRecursiveReasoningModel_ACTV1InnerCarry,
+ batch: dict[str, torch.Tensor],
+ after: int,
+ candidates: int,
+ fd_eps: float,
+ generator: torch.Generator,
+):
+ bsz = batch["inputs"].shape[0]
+ remaining = model.config.halt_max_steps - after
+ h_dirs, l_dirs = rand_unit_dirs(warm, candidates, generator)
+ main = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=warm.z_H.detach(), z_L=warm.z_L.detach())
+ shadow = make_shadow_inner(main, h_dirs, l_dirs, fd_eps)
+ combined = cat_inner(main, shadow)
+ combined_batch = cat_batches(batch, repeat_batch(batch, candidates))
+
+ logits = None
+ for _ in range(remaining):
+ combined, logits, _q = model.inner(combined, combined_batch)
+ assert logits is not None
+ main_final, shadow_final = split_inner(combined, bsz)
+ main_logits = logits[:bsz]
+ sep = separation(main_final, shadow_final, candidates)
+ best_idx = sep.argmax(dim=1)
+ best_sep = sep.gather(1, best_idx[:, None]).squeeze(1)
+ best_h, best_l = gather_dirs(h_dirs, l_dirs, best_idx)
+ growth = torch.log(best_sep / fd_eps).float() / max(remaining, 1)
+ clean_exact, clean_token = correctness(main_logits, batch["labels"])
+ return best_h, best_l, growth, clean_exact, clean_token
+
+
+@torch.inference_mode()
+def eval_directional_sigma(
+ model,
+ warm: TinyRecursiveReasoningModel_ACTV1InnerCarry,
+ batch: dict[str, torch.Tensor],
+ after: int,
+ h_dir: torch.Tensor,
+ l_dir: torch.Tensor,
+ sigma: float,
+):
+ remaining = model.config.halt_max_steps - after
+ plus = perturb_inner(warm, h_dir, l_dir, sigma, +1.0)
+ minus = perturb_inner(warm, h_dir, l_dir, sigma, -1.0)
+ inner = cat_inner(plus, minus)
+ combined_batch = cat_batches(batch, batch)
+ logits = None
+ for _ in range(remaining):
+ inner, logits, _q = model.inner(inner, combined_batch)
+ assert logits is not None
+ exact, token = correctness(logits, combined_batch["labels"])
+ bsz = batch["inputs"].shape[0]
+ return exact.view(2, bsz).transpose(0, 1).contiguous(), token.view(2, bsz).transpose(0, 1).contiguous()
+
+
+def summarize(exact: torch.Tensor, token: torch.Tensor, clean_exact: torch.Tensor, growth: torch.Tensor):
+ clean_success = clean_exact.bool()
+ clean_fail = ~clean_success
+ both = exact.all(dim=1)
+ either = exact.any(dim=1)
+ out = {
+ "clean_acc": clean_success.float().mean().item(),
+ "mean_sign_exact": exact.float().mean().item(),
+ "mean_sign_token_acc": token.mean().item(),
+ "worst_sign_exact": both.float().mean().item(),
+ "best_sign_exact": either.float().mean().item(),
+ "selected_growth_mean": growth.mean().item(),
+ "selected_growth_q90": torch.quantile(growth, 0.90).item(),
+ }
+ if clean_success.any().item():
+ out["retain_mean_on_clean_success"] = exact[clean_success].float().mean().item()
+ out["retain_worst_on_clean_success"] = both[clean_success].float().mean().item()
+ else:
+ out["retain_mean_on_clean_success"] = float("nan")
+ out["retain_worst_on_clean_success"] = float("nan")
+ if clean_fail.any().item():
+ out["rescue_mean_on_clean_fail"] = exact[clean_fail].float().mean().item()
+ out["rescue_best_on_clean_fail"] = either[clean_fail].float().mean().item()
+ else:
+ out["rescue_mean_on_clean_fail"] = float("nan")
+ out["rescue_best_on_clean_fail"] = float("nan")
+ return out
+
+
+def write_summary(path: Path, rows: list[dict[str, Any]]) -> None:
+ keys = list(rows[0])
+ for row in rows[1:]:
+ for key in row:
+ if key not in keys:
+ keys.append(key)
+ with path.open("w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=keys)
+ writer.writeheader()
+ writer.writerows(rows)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ckpt-root", required=True)
+ parser.add_argument("--ckpt-name", required=True)
+ parser.add_argument("--label", required=True)
+ parser.add_argument("--n-samples", type=int, default=1000)
+ parser.add_argument("--batch-size", type=int, default=16)
+ parser.add_argument("--candidates", type=int, default=8)
+ parser.add_argument("--fd-eps", type=float, default=1e-3)
+ parser.add_argument("--sigmas", default="0,0.001,0.003,0.01,0.03,0.1")
+ parser.add_argument("--perturb-afters", default="0,4,8,12")
+ parser.add_argument("--seed", type=int, default=20260607)
+ parser.add_argument("--out-prefix", required=True)
+ args = parser.parse_args()
+
+ device = "cuda"
+ sigmas = parse_float_list(args.sigmas)
+ if not any(abs(s) <= 1e-12 for s in sigmas):
+ sigmas = [0.0] + sigmas
+ afters = parse_int_list(args.perturb_afters)
+ torch.manual_seed(args.seed)
+ generator = torch.Generator(device=device).manual_seed(args.seed + 17)
+
+ model, cfg, data_path = load_model(Path(args.ckpt_root), args.ckpt_name, device)
+ samples = load_test_samples(data_path, args.n_samples, args.seed)
+ n = len(samples["inputs"])
+ print(
+ f"[run] label={args.label} n={n} batch={args.batch_size} candidates={args.candidates} "
+ f"afters={afters} sigmas={sigmas}",
+ flush=True,
+ )
+
+ rows: list[dict[str, Any]] = []
+ exact_store = []
+ token_store = []
+ growth_store = []
+ for after in afters:
+ sigma_exact = {sigma: [] for sigma in sigmas}
+ sigma_token = {sigma: [] for sigma in sigmas}
+ growth_parts = []
+ for start in range(0, n, args.batch_size):
+ end = min(start + args.batch_size, n)
+ batch = batch_slice(samples, start, end, device)
+ warm, _ = warmup_inner(model, batch, after)
+ h_dir, l_dir, growth, _search_clean_exact, _search_clean_token = search_direction(
+ model, warm, batch, after, args.candidates, args.fd_eps, generator
+ )
+ growth_parts.append(growth.cpu())
+ for sigma in sigmas:
+ exact, token = eval_directional_sigma(model, warm, batch, after, h_dir, l_dir, sigma)
+ sigma_exact[sigma].append(exact.cpu())
+ sigma_token[sigma].append(token.cpu())
+ if end == n or (end // args.batch_size) % 10 == 0:
+ print(f" after={after} [{end}/{n}]", flush=True)
+
+ growth_all = torch.cat(growth_parts, dim=0)
+ growth_store.append(growth_all.numpy())
+ exact_by_sigma = {sigma: torch.cat(sigma_exact[sigma], dim=0) for sigma in sigmas}
+ token_by_sigma = {sigma: torch.cat(sigma_token[sigma], dim=0) for sigma in sigmas}
+ clean_all = exact_by_sigma[0.0][:, 0].clone()
+ print(f" after={after} clean grouping done clean_acc={clean_all.float().mean().item():.4f}", flush=True)
+ for sigma in sigmas:
+ exact_all = exact_by_sigma[sigma]
+ token_all = token_by_sigma[sigma]
+ row: dict[str, Any] = {
+ "label": args.label,
+ "perturb_after": after,
+ "sigma": sigma,
+ "n_samples": n,
+ "candidates": args.candidates,
+ "fd_eps": args.fd_eps,
+ "ckpt_root": str(Path(args.ckpt_root)),
+ "ckpt_name": args.ckpt_name,
+ **summarize(exact_all, token_all, clean_all, growth_all),
+ }
+ rows.append(row)
+ exact_store.append(exact_all.numpy())
+ token_store.append(token_all.numpy())
+ print(
+ f" after={after} sigma={sigma:g} clean={row['clean_acc']:.4f} "
+ f"mean={row['mean_sign_exact']:.4f} worst={row['worst_sign_exact']:.4f} "
+ f"retain_worst={row['retain_worst_on_clean_success']:.4f} "
+ f"rescue_best={row['rescue_best_on_clean_fail']:.4f}",
+ flush=True,
+ )
+
+ out_prefix = Path(args.out_prefix)
+ out_prefix.parent.mkdir(parents=True, exist_ok=True)
+ write_summary(out_prefix.with_suffix(".summary.csv"), rows)
+ meta = {
+ "args": vars(args),
+ "data_path": str(data_path),
+ "config_global_batch_size": cfg.get("global_batch_size"),
+ "sigmas": sigmas,
+ "perturb_afters": afters,
+ "n_samples": n,
+ }
+ out_prefix.with_suffix(".meta.json").write_text(json.dumps(meta, indent=2, sort_keys=True))
+ np.savez_compressed(
+ out_prefix.with_suffix(".npz"),
+ idx=samples["idx"],
+ sigmas=np.asarray(sigmas, dtype=np.float32),
+ perturb_afters=np.asarray(afters, dtype=np.int32),
+ exact=np.stack(exact_store, axis=0),
+ token_acc=np.stack(token_store, axis=0),
+ selected_growth=np.stack(growth_store, axis=0),
+ meta_json=np.asarray(json.dumps(meta, sort_keys=True)),
+ )
+ print(f"[done] {out_prefix}.summary.csv", flush=True)
+
+
+if __name__ == "__main__":
+ main()