diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-05-23 04:56:47 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-05-23 04:56:47 -0500 |
| commit | 152821462023690df5d2bf90812e1cb5b1ca7274 (patch) | |
| tree | 9359c27d81b41dc0372cb82ef9c0ec3540d254e3 /scripts | |
| parent | d11a0f6432e26c3243123d5e19aaf2702c76d64c (diff) | |
Add SRM training pipeline
- config/arch/srm_v1.yaml: arch config for pretrain.py integration
- scripts/train_srm.py: standalone from-scratch trainer based on step4
(HRM training infra adapted for SRM joint operator)
The arch.yaml exposes κ, η, α, n_iters, n_aol_layers as Hydra params.
train_srm.py adds joint Lyapunov diagnostic via JVP on srm_block to verify
λ_1 ≤ log((1-α)+α·κ) per micro-step. Smoke tested with hidden=128, n_iters=4
on Sudoku 1k: empirical Lip 0.28 << bound 0.90.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Diffstat (limited to 'scripts')
| -rw-r--r-- | scripts/train_srm.py | 300 |
1 files changed, 300 insertions, 0 deletions
diff --git a/scripts/train_srm.py b/scripts/train_srm.py new file mode 100644 index 0000000..036ab82 --- /dev/null +++ b/scripts/train_srm.py @@ -0,0 +1,300 @@ +"""Train SRM-Joint-AOL from scratch on Sudoku 1k (or any HRM-format dataset). + +By construction the SRM joint step is ≤ κ-Lipschitz in P-norm, so this trainer +uses ONLY supervised ACT loss — no CF regularizer needed. λ_1 is logged as +a diagnostic; it should stay ≤ log((1-α)+α·κ) per micro-step (e.g. -0.105 for κ=0.9, α=1). + +Usage (run from /home/yurenh2/rrm/srm/): + python scripts/train_srm.py --n-steps 3000 --batch-size 8 \ + --out runs/srm_v1_sudoku_3k.json \ + --save-ckpt ckpts/srm_v1_3k.pt +""" +from __future__ import annotations +import sys, os, json, math, time, argparse +from pathlib import Path +import numpy as np +import torch + +ROOT = Path("/home/yurenh2/rrm/srm") +sys.path.insert(0, str(ROOT)) + +from models.srm.srm_aol_v1 import ( + StableRecursionModel_ACTV1, + StableRecursionModel_ACTV1_Inner, + measure_lipschitz_constant, +) +from models.losses import ACTLossHead +from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed +from adam_atan2 import AdamATan2 + + +def build_srm_from_scratch(data_path: Path, batch_size: int, device: str, + hidden_size: int = 512, + n_iters: int = 12, + n_aol_layers: int = 2, + kappa: float = 0.9, + eta: float = 1.0, + alpha: float = 1.0): + train_meta = json.loads((data_path / "train" / "dataset.json").read_text()) + arch_cfg = dict( + hidden_size=hidden_size, + n_iters=n_iters, + n_aol_layers=n_aol_layers, + kappa=kappa, eta=eta, alpha=alpha, + halt_max_steps=16, halt_exploration_prob=0.1, + puzzle_emb_ndim=hidden_size, + batch_size=batch_size, + vocab_size=train_meta["vocab_size"], + seq_len=train_meta["seq_len"], + num_puzzle_identifiers=train_meta["num_puzzle_identifiers"], + forward_dtype="bfloat16", + ) + with torch.device(device): + base = StableRecursionModel_ACTV1(arch_cfg) + head = ACTLossHead(base, loss_type="stablemax_cross_entropy") + return head, base, train_meta + + +@torch.no_grad() +def compute_joint_lyap_spec_srm(inner: StableRecursionModel_ACTV1_Inner, batch, k_lyap, n_iters_for_lyap, + device, seed): + """Top-k joint Lyapunov spectrum for SRM dynamics. + + Tangent: at each step the Jacobian J = ∂T/∂(h,l) is applied to all k orthonormal + columns via JVP. Then QR re-orthogonalize. + """ + cfg = inner.config + B = batch["inputs"].shape[0] + seq_full = cfg.seq_len + inner.puzzle_emb_len + hidden = cfg.hidden_size + D = seq_full * hidden + + z_H = inner.H_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) + z_L = inner.L_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) + input_emb = inner._input_embeddings(batch["inputs"], batch["puzzle_identifiers"]) + + g = torch.Generator(device=device).manual_seed(seed) + Q0 = torch.randn(B, 2 * D, k_lyap, device=device, dtype=torch.float32, generator=g) + Q, _ = torch.linalg.qr(Q0) + log_R_sum = torch.zeros(B, k_lyap, device=device, dtype=torch.float32) + n_steps_lyap = 0 + + for _ in range(n_iters_for_lyap): + # JVP through srm_block w.r.t. (z_H, z_L) — one tangent column at a time + new_cols = [] + for i in range(k_lyap): + v_H = Q[:, :D, i].reshape(B, seq_full, hidden).to(inner.forward_dtype) + v_L = Q[:, D:, i].reshape(B, seq_full, hidden).to(inner.forward_dtype) + + def f(zH_zL): + zH, zL = zH_zL[:, :hidden, :].permute(0, 2, 1).contiguous(), zH_zL[:, hidden:, :].permute(0, 2, 1).contiguous() + hN, lN = inner.srm_block(zH, zL, input_emb) + return torch.stack([hN, lN], dim=1).reshape(B, 2 * hidden, seq_full) + + # Easier: use 2 JVPs separately if function takes (h, l) + def f_joint(zH, zL): + return inner.srm_block(zH, zL, input_emb) + (hN, lN), (dh_out, dl_out) = torch.autograd.functional.jvp( + f_joint, (z_H, z_L), v=(v_H, v_L), create_graph=False, strict=False) + dh_col = dh_out.reshape(B, D).to(torch.float32) + dl_col = dl_out.reshape(B, D).to(torch.float32) + new_cols.append(torch.cat([dh_col, dl_col], dim=-1)) + Q = torch.stack(new_cols, dim=-1) # (B, 2D, k) + # Advance state + z_H, z_L = hN, lN + # Orthonormalize + Q, R = torch.linalg.qr(Q) + log_R_sum = log_R_sum + R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log() + n_steps_lyap += 1 + + return log_R_sum / max(n_steps_lyap, 1) # (B, k) + + +def load_train_batches(data_path: Path, batch_size: int, n_iters: int, seed: int = 0): + rng = np.random.default_rng(seed) + inputs = np.load(data_path / "train" / "all__inputs.npy") + labels = np.load(data_path / "train" / "all__labels.npy") + pid = np.load(data_path / "train" / "all__puzzle_identifiers.npy") + N = len(inputs) + for _ in range(n_iters): + idx = rng.choice(N, size=batch_size, replace=False) + yield { + "inputs": torch.from_numpy(inputs[idx].astype(np.int32)), + "labels": torch.from_numpy(labels[idx].astype(np.int32)), + "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)), + } + + +def evaluate(head, base, data_path, n_samples, batch_size, device, seed=42): + rng = np.random.default_rng(seed) + inputs = np.load(data_path / "test" / "all__inputs.npy") + labels = np.load(data_path / "test" / "all__labels.npy") + pid = np.load(data_path / "test" / "all__puzzle_identifiers.npy") + idx_all = rng.choice(len(inputs), size=n_samples, replace=False) + head.eval() + correct = 0; token_correct = 0; token_total = 0 + for s in range(0, n_samples, batch_size): + e = min(s + batch_size, n_samples) + idx = idx_all[s:e] + batch = { + "inputs": torch.from_numpy(inputs[idx].astype(np.int32)).to(device), + "labels": torch.from_numpy(labels[idx].astype(np.int32)).to(device), + "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)).to(device), + } + with torch.no_grad(): + with torch.device(device): + carry = base.initial_carry(batch) + for _ in range(base.config.halt_max_steps): + carry, outputs = base(carry=carry, batch=batch) + preds = outputs["logits"].argmax(dim=-1) + mask = batch["labels"] > 0 + exact = ((preds == batch["labels"]) | ~mask).all(dim=-1).float() + correct += exact.sum().item() + token_correct += ((preds == batch["labels"]) & mask).sum().item() + token_total += mask.sum().item() + return correct / n_samples, token_correct / max(token_total, 1) + + +def warmup_constant_lr(step, base_lr, warmup): + return base_lr * step / max(1, warmup) if step < warmup else base_lr + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--data-path", default="/home/yurenh2/rrm/data/sudoku-extreme-1k-aug-1000") + ap.add_argument("--n-steps", type=int, default=3000) + ap.add_argument("--batch-size", type=int, default=8) + ap.add_argument("--lr", type=float, default=1e-4) + ap.add_argument("--puzzle-emb-lr", type=float, default=1e-4) + ap.add_argument("--warmup-steps", type=int, default=200) + ap.add_argument("--weight-decay", type=float, default=1.0) + # SRM specific + ap.add_argument("--hidden-size", type=int, default=512) + ap.add_argument("--n-iters", type=int, default=12) + ap.add_argument("--n-aol-layers", type=int, default=2) + ap.add_argument("--kappa", type=float, default=0.9) + ap.add_argument("--eta", type=float, default=1.0) + ap.add_argument("--alpha", type=float, default=1.0) + # Diagnostic + ap.add_argument("--k-lyap", type=int, default=2) + ap.add_argument("--lyap-iters", type=int, default=8, help="number of SRM steps for Lyapunov measurement") + ap.add_argument("--lyap-every", type=int, default=50, help="measure Lyapunov every N steps (expensive)") + # Eval / logging + ap.add_argument("--seed", type=int, default=42) + ap.add_argument("--eval-every", type=int, default=250) + ap.add_argument("--eval-n", type=int, default=512) + ap.add_argument("--eval-batch-size", type=int, default=32) + ap.add_argument("--out", required=True) + ap.add_argument("--save-ckpt", default="") + args = ap.parse_args() + + device = "cuda" + torch.manual_seed(args.seed); np.random.seed(args.seed) + data_path = Path(args.data_path) + head, base, train_meta = build_srm_from_scratch( + data_path, args.batch_size, device, + hidden_size=args.hidden_size, n_iters=args.n_iters, + n_aol_layers=args.n_aol_layers, + kappa=args.kappa, eta=args.eta, alpha=args.alpha, + ) + n_params = sum(p.numel() for p in head.parameters()) + print(f"Built SRM-AOL from scratch | params={n_params:,} | " + f"hidden={args.hidden_size} n_iters={args.n_iters} n_aol={args.n_aol_layers} " + f"κ={args.kappa} η={args.eta} α={args.alpha}") + + puzzle_emb_opt = CastedSparseEmbeddingSignSGD_Distributed( + base.inner.puzzle_emb.buffers(), lr=0, + weight_decay=args.weight_decay, world_size=1, + ) + main_opt = AdamATan2(head.parameters(), lr=0, betas=(0.9, 0.95), weight_decay=args.weight_decay) + + # Initial eval (random init baseline) + Lipschitz check + acc0, tacc0 = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) + print(f"=== step 0 (random init): exact_acc = {acc0:.4f} token_acc = {tacc0:.4f} ===") + # Sample one batch for the initial Lipschitz check + probe_batch = next(load_train_batches(data_path, args.batch_size, 1, seed=999)) + probe_batch = {k: v.to(device) for k, v in probe_batch.items()} + lip0 = measure_lipschitz_constant(base.inner, probe_batch, n_probes=32) + print(f" Lip init: emp_max={lip0['lip_emp_max']:.4f} bound={lip0['lip_theoretical_bound']:.4f}") + + log = { + "args": vars(args), "n_params": n_params, + "initial_acc": acc0, "initial_tok_acc": tacc0, + "initial_lip": lip0, + "steps": [], "evals": [], + } + log["evals"].append({"step": 0, "acc": acc0, "tok_acc": tacc0}) + t0 = time.time() + train_iter = load_train_batches(data_path, args.batch_size, args.n_steps, seed=args.seed) + + for step, batch in enumerate(train_iter): + batch = {k: v.to(device) for k, v in batch.items()} + cur_lr = warmup_constant_lr(step, args.lr, args.warmup_steps) + cur_pe_lr = warmup_constant_lr(step, args.puzzle_emb_lr, args.warmup_steps) + for pg in main_opt.param_groups: pg["lr"] = cur_lr + for pg in puzzle_emb_opt.param_groups: pg["lr"] = cur_pe_lr + + head.train() + with torch.device(device): + carry = base.initial_carry(batch) + sup_loss_sum = 0.0; n_loss = 0 + for _ in range(base.config.halt_max_steps): + carry, l, metrics, _, all_finish = head(return_keys=[], carry=carry, batch=batch) + sup_loss_sum = sup_loss_sum + l + n_loss += 1 + if all_finish: break + sup_loss = sup_loss_sum / max(n_loss, 1) / args.batch_size + + puzzle_emb_opt.zero_grad(set_to_none=True) + main_opt.zero_grad(set_to_none=True) + sup_loss.backward() + torch.nn.utils.clip_grad_norm_([p for p in head.parameters() if p.requires_grad], 1.0) + main_opt.step() + puzzle_emb_opt.step() + + rec = {"step": step, "lr": cur_lr, "sup_loss": float(sup_loss.item())} + + # Lyapunov diagnostic (every lyap_every steps) + if step % args.lyap_every == 0: + lyap_spec = compute_joint_lyap_spec_srm( + base.inner, batch, k_lyap=args.k_lyap, + n_iters_for_lyap=args.lyap_iters, + device=device, seed=args.seed + step, + ) # (B, k) + rec["lyap1_mean"] = float(lyap_spec[:, 0].mean().item()) + rec["lyap1_max"] = float(lyap_spec[:, 0].max().item()) + rec["lyap_spec_mean"] = lyap_spec.mean(dim=0).cpu().tolist() + log_kappa_bound = math.log((1 - args.alpha) + args.alpha * args.kappa) + rec["lyap_bound"] = log_kappa_bound + log["steps"].append(rec) + if step % 25 == 0 or step == args.n_steps - 1: + extra = f" λ={rec.get('lyap1_mean', float('nan')):+.4f} max={rec.get('lyap1_max', float('nan')):+.4f}" if "lyap1_mean" in rec else "" + print(f" [{step:>4}/{args.n_steps}] dt={time.time()-t0:.0f}s lr={cur_lr:.1e} " + f"sup={rec['sup_loss']:.4f}{extra}", flush=True) + + if (step + 1) % args.eval_every == 0 or step == args.n_steps - 1: + acc, tacc = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) + print(f" >> EVAL @ {step+1}: exact_acc={acc:.4f} tok_acc={tacc:.4f} " + f"(Δ init: {acc-acc0:+.4f})", flush=True) + log["evals"].append({"step": step + 1, "acc": acc, "tok_acc": tacc}) + + log["final_acc"] = log["evals"][-1]["acc"] + log["final_tok_acc"] = log["evals"][-1]["tok_acc"] + Path(args.out).parent.mkdir(parents=True, exist_ok=True) + Path(args.out).write_text(json.dumps(log, indent=2)) + print(f"\n=== DONE === init {acc0:.4f} → final {log['final_acc']:.4f} log → {args.out}") + + if args.save_ckpt: + Path(args.save_ckpt).parent.mkdir(parents=True, exist_ok=True) + torch.save({ + "state_dict": head.state_dict(), + "args": vars(args), + "n_steps_trained": args.n_steps, + "final_acc": log["final_acc"], + "n_params": n_params, + }, args.save_ckpt) + print(f"checkpoint → {args.save_ckpt}") + + +if __name__ == "__main__": + main() |
