summaryrefslogtreecommitdiff
path: root/research/flossing/sanity_lipschitz_check.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
commit66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch)
treec29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/sanity_lipschitz_check.py
rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipelineHEADmain
Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/sanity_lipschitz_check.py')
-rw-r--r--research/flossing/sanity_lipschitz_check.py184
1 files changed, 184 insertions, 0 deletions
diff --git a/research/flossing/sanity_lipschitz_check.py b/research/flossing/sanity_lipschitz_check.py
new file mode 100644
index 0000000..91fffee
--- /dev/null
+++ b/research/flossing/sanity_lipschitz_check.py
@@ -0,0 +1,184 @@
+"""Empirical Lipschitz sanity check: perturb init state by small noise,
+measure how OUTPUT and final z_H change. Independent of our JVP code.
+
+If TRM succ samples truly have λ > 0, perturbations should diverge through dynamics.
+If they're actually stable in output subspace, perturbations decay or stay bounded.
+"""
+import sys, yaml, json, math
+from pathlib import Path
+import numpy as np
+import torch
+
+HRM_DIR = Path("/home/yurenh2/rrm/hrm")
+TRM_DIR = Path("/home/yurenh2/rrm/trm")
+
+CKPT_TRM_ROOT = "/home/yurenh2/rrm/trm/checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_mlp_t_sudoku_singleGPU"
+CKPT_TRM_NAME = "step_104164"
+CKPT_HRM_ROOT = "/home/yurenh2/rrm/hrm/checkpoints/Sudoku-extreme-1k-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 righteous-python"
+CKPT_HRM_NAME = "step_26040"
+
+DEVICE = "cuda"
+
+
+def load_model(repo_dir, ckpt_root, ckpt_name, model_cls_path):
+ # Clear cached modules from other repo to avoid conflicts (HRM/TRM both have models.*)
+ for mod in list(sys.modules.keys()):
+ if mod.startswith("models"):
+ del sys.modules[mod]
+ sys.path[:] = [p for p in sys.path if not (p.endswith("/hrm") or p.endswith("/trm"))]
+ sys.path.insert(0, str(repo_dir))
+ import importlib
+ mod_path, cls_name = model_cls_path.split("@")
+ cls = getattr(importlib.import_module(mod_path), cls_name)
+ cfg = yaml.safe_load((Path(ckpt_root) / "all_config.yaml").read_text())
+ arch_cfg = dict(cfg["arch"])
+ data_path = Path(cfg.get("data_path") or cfg["data_paths"][0])
+ train_meta = json.loads((data_path / "train" / "dataset.json").read_text())
+ arch_cfg.update(batch_size=cfg["global_batch_size"], seq_len=train_meta["seq_len"],
+ vocab_size=train_meta["vocab_size"],
+ num_puzzle_identifiers=train_meta["num_puzzle_identifiers"], causal=False)
+ model = cls(arch_cfg)
+ sd = torch.load(Path(ckpt_root) / ckpt_name, map_location="cpu", weights_only=True)
+ sd = {k.replace("_orig_mod.", "").replace("model.", ""): v for k, v in sd.items()}
+ missing, unexpected = model.load_state_dict(sd, strict=False)
+ print(f" [load] missing={len(missing)} unexpected={len(unexpected)}")
+ model.to(DEVICE).eval()
+ return model, cfg, train_meta, data_path
+
+
+def load_test_samples(data_path, n_total, seed=0):
+ rng = np.random.default_rng(seed)
+ inputs = np.load(data_path / "test" / "all__inputs.npy")
+ labels = np.load(data_path / "test" / "all__labels.npy")
+ pid = np.load(data_path / "test" / "all__puzzle_identifiers.npy")
+ idx = rng.choice(len(inputs), size=n_total, replace=False)
+ return {
+ "inputs": torch.from_numpy(inputs[idx].astype(np.int32)).to(DEVICE),
+ "labels": torch.from_numpy(labels[idx].astype(np.int32)).to(DEVICE),
+ "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)).to(DEVICE),
+ }
+
+
+@torch.no_grad()
+def measure_pert_stability(model, batch, eps=1e-2, n_act_steps=8):
+ """For each sample, run UNPERTURBED + PERTURBED full forward.
+ Track:
+ - δz_H (final): norm of z_H change at end of all ACT steps
+ - δz_L (final): same for z_L
+ - argmax flip: did the prediction change?
+ Returns per-sample stats.
+
+ The "growth rate" inferred from δz_final / δz_init can be compared to JVP λ.
+ If λ_JVP > 0, δz_final >> δz_init (expansion). If λ_JVP < 0, δz_final < δz_init.
+ """
+ inner = model.inner
+ cfg = inner.config
+ B = batch["inputs"].shape[0]
+ seq_full = cfg.seq_len + inner.puzzle_emb_len
+ hidden = cfg.hidden_size
+ dt = inner.forward_dtype
+
+ # Initial z_H, z_L (identical for both runs initially)
+ z_H_0 = inner.H_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(dt)
+ z_L_0 = inner.L_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(dt)
+
+ # Perturbation
+ g = torch.Generator(device=DEVICE).manual_seed(42)
+ delta_H = torch.randn(B, seq_full, hidden, generator=g, dtype=torch.float32, device=DEVICE).to(dt) * eps
+ delta_L = torch.randn(B, seq_full, hidden, generator=g, dtype=torch.float32, device=DEVICE).to(dt) * eps
+
+ input_emb = inner._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
+ seq_info = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None)
+ init_delta_norm = (delta_H.float().flatten(1).norm(dim=1) +
+ delta_L.float().flatten(1).norm(dim=1)) # (B,) sum of init pert norms
+
+ # Run unperturbed and perturbed in parallel
+ z_H_a, z_L_a = z_H_0.clone(), z_L_0.clone()
+ z_H_b, z_L_b = z_H_0 + delta_H, z_L_0 + delta_L
+
+ has_H_level = hasattr(inner, "H_level") # HRM has separate, TRM uses L_level for both
+
+ n_total = 0
+ for _act in range(n_act_steps):
+ for _h in range(cfg.H_cycles):
+ for _l in range(cfg.L_cycles):
+ z_L_a = inner.L_level(z_L_a, z_H_a + input_emb, **seq_info)
+ z_L_b = inner.L_level(z_L_b, z_H_b + input_emb, **seq_info)
+ n_total += 1
+ # H step: use H_level (HRM) or L_level (TRM)
+ h_mod = inner.H_level if has_H_level else inner.L_level
+ z_H_a = h_mod(z_H_a, z_L_a, **seq_info)
+ z_H_b = h_mod(z_H_b, z_L_b, **seq_info)
+ n_total += 1
+
+ final_delta_norm = ((z_H_b - z_H_a).float().flatten(1).norm(dim=1) +
+ (z_L_b - z_L_a).float().flatten(1).norm(dim=1))
+
+ # Per-sample growth rate per micro-step
+ # δ_final ≈ δ_init * exp(λ * n_total) → λ ≈ log(δ_final/δ_init) / n_total
+ ratio = final_delta_norm / init_delta_norm.clamp_min(1e-12)
+ lam_emp = ratio.log() / n_total
+
+ # Read out predictions for both runs
+ out_a = inner.lm_head(z_H_a)[:, inner.puzzle_emb_len:].float()
+ out_b = inner.lm_head(z_H_b)[:, inner.puzzle_emb_len:].float()
+ pred_a = out_a.argmax(dim=-1)
+ pred_b = out_b.argmax(dim=-1)
+ labels = batch["labels"]
+ mask = labels > 0
+ exact_a = ((pred_a == labels) | ~mask).all(dim=-1)
+ exact_b = ((pred_b == labels) | ~mask).all(dim=-1)
+ pred_flip = (pred_a != pred_b).any(dim=-1) # any token changed
+
+ return {
+ "init_norm": init_delta_norm.cpu(),
+ "final_norm": final_delta_norm.cpu(),
+ "ratio": ratio.cpu(),
+ "lam_emp": lam_emp.cpu(),
+ "succ_a": exact_a.cpu(),
+ "succ_b": exact_b.cpu(),
+ "pred_flip": pred_flip.cpu(),
+ }
+
+
+def main():
+ for name, repo, ckpt_root, ckpt_name, mod_path in [
+ ("HRM step_26040", HRM_DIR, CKPT_HRM_ROOT, CKPT_HRM_NAME,
+ "models.hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1"),
+ ("TRM step_104164", TRM_DIR, CKPT_TRM_ROOT, CKPT_TRM_NAME,
+ "models.recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1"),
+ ]:
+ print(f"\n=== {name} ===")
+ model, cfg, train_meta, data_path = load_model(repo, ckpt_root, ckpt_name, mod_path)
+ batch = load_test_samples(data_path, n_total=64, seed=0)
+ # Limit batch size to model's training batch (puzzle_emb buffer)
+ # Re-batch
+ B = 16
+ results = {"lam_emp": [], "succ_a": [], "ratio": []}
+ for s in range(0, 64, B):
+ e = min(s + B, 64)
+ mb = {k: v[s:e] for k, v in batch.items()}
+ # Rebuild model puzzle_emb buffer if needed — easier: ensure model's batch_size matches
+ r = measure_pert_stability(model, mb, eps=1e-2, n_act_steps=8)
+ results["lam_emp"].append(r["lam_emp"])
+ results["succ_a"].append(r["succ_a"])
+ results["ratio"].append(r["ratio"])
+ lam = torch.cat(results["lam_emp"]).numpy()
+ succ = torch.cat(results["succ_a"]).numpy()
+ ratio = torch.cat(results["ratio"]).numpy()
+ print(f" N=64 acc={succ.mean():.3f}")
+ print(f" finite-diff λ_emp (per micro-step):")
+ print(f" all mean={lam.mean():+.4f} med={np.median(lam):+.4f} range=[{lam.min():+.4f}, {lam.max():+.4f}]")
+ if succ.sum() > 0:
+ print(f" succ mean={lam[succ.astype(bool)].mean():+.4f} med={np.median(lam[succ.astype(bool)]):+.4f}")
+ if (~succ).sum() > 0:
+ print(f" fail mean={lam[~succ.astype(bool)].mean():+.4f} med={np.median(lam[~succ.astype(bool)]):+.4f}")
+ print(f" final/init perturbation ratio:")
+ print(f" all mean={ratio.mean():.3f} med={np.median(ratio):.3f} range=[{ratio.min():.3e}, {ratio.max():.3e}]")
+ # cleanup
+ del model
+ torch.cuda.empty_cache()
+
+
+if __name__ == "__main__":
+ main()