summaryrefslogtreecommitdiff
path: root/scripts/train_hrm_orth.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/train_hrm_orth.py')
-rw-r--r--scripts/train_hrm_orth.py203
1 files changed, 203 insertions, 0 deletions
diff --git a/scripts/train_hrm_orth.py b/scripts/train_hrm_orth.py
new file mode 100644
index 0000000..4d9868b
--- /dev/null
+++ b/scripts/train_hrm_orth.py
@@ -0,0 +1,203 @@
+"""Train HRM-Orth (orthogonal-patched HRM) from scratch on Sudoku.
+
+Per codex round 2 recommendation (Q6 pivot): patch HRM Block (attn+SwiGLU+rms_norm)
+with Lipschitz-bounded versions (cosine attn + OrthLinear+MaxMin + weighted residual).
+Keeps HRM's H_level/L_level/ACT framework intact.
+"""
+from __future__ import annotations
+import sys, os, json, math, time, argparse
+from pathlib import Path
+import numpy as np
+import torch
+
+ROOT = Path("/home/yurenh2/rrm/srm")
+sys.path.insert(0, str(ROOT))
+
+from models.srm.hrm_orth_v1 import HierarchicalReasoningModel_ACTV1 as HRMOrth
+from models.losses import ACTLossHead
+from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed
+from adam_atan2 import AdamATan2
+
+
+def build_model(data_path: Path, batch_size: int, device: str,
+ hidden_size: int = 256, num_heads: int = 4,
+ H_cycles: int = 2, L_cycles: int = 2, H_layers: int = 4, L_layers: int = 4,
+ orth_s_min: float = 0.95, cosine_attn_tau: float = 8.0):
+ train_meta = json.loads((data_path / "train" / "dataset.json").read_text())
+ arch_cfg = dict(
+ H_cycles=H_cycles, H_layers=H_layers,
+ L_cycles=L_cycles, L_layers=L_layers,
+ expansion=4,
+ halt_exploration_prob=0.1,
+ halt_max_steps=16,
+ hidden_size=hidden_size,
+ num_heads=num_heads,
+ pos_encodings="rope",
+ puzzle_emb_ndim=hidden_size,
+ batch_size=batch_size,
+ vocab_size=train_meta["vocab_size"],
+ seq_len=train_meta["seq_len"],
+ num_puzzle_identifiers=train_meta["num_puzzle_identifiers"],
+ forward_dtype="bfloat16",
+ orth_s_min=orth_s_min,
+ cosine_attn_tau=cosine_attn_tau,
+ )
+ with torch.device(device):
+ base = HRMOrth(arch_cfg)
+ head = ACTLossHead(base, loss_type="stablemax_cross_entropy")
+ return head, base, train_meta
+
+
+def load_train_batches(data_path: Path, batch_size: int, n_iters: int, seed: int = 0):
+ rng = np.random.default_rng(seed)
+ inputs = np.load(data_path / "train" / "all__inputs.npy")
+ labels = np.load(data_path / "train" / "all__labels.npy")
+ pid = np.load(data_path / "train" / "all__puzzle_identifiers.npy")
+ N = len(inputs)
+ for _ in range(n_iters):
+ idx = rng.choice(N, size=batch_size, replace=False)
+ yield {
+ "inputs": torch.from_numpy(inputs[idx].astype(np.int32)),
+ "labels": torch.from_numpy(labels[idx].astype(np.int32)),
+ "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)),
+ }
+
+
+def evaluate(head, base, data_path, n_samples, batch_size, device, seed=42):
+ rng = np.random.default_rng(seed)
+ inputs = np.load(data_path / "test" / "all__inputs.npy")
+ labels = np.load(data_path / "test" / "all__labels.npy")
+ pid = np.load(data_path / "test" / "all__puzzle_identifiers.npy")
+ idx_all = rng.choice(len(inputs), size=n_samples, replace=False)
+ head.eval()
+ correct = 0; token_correct = 0; token_total = 0
+ for s in range(0, n_samples, batch_size):
+ e = min(s + batch_size, n_samples)
+ idx = idx_all[s:e]
+ batch = {
+ "inputs": torch.from_numpy(inputs[idx].astype(np.int32)).to(device),
+ "labels": torch.from_numpy(labels[idx].astype(np.int32)).to(device),
+ "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)).to(device),
+ }
+ with torch.no_grad():
+ with torch.device(device):
+ carry = base.initial_carry(batch)
+ for _ in range(base.config.halt_max_steps):
+ carry, outputs = base(carry=carry, batch=batch)
+ preds = outputs["logits"].argmax(dim=-1)
+ mask = batch["labels"] > 0
+ exact = ((preds == batch["labels"]) | ~mask).all(dim=-1).float()
+ correct += exact.sum().item()
+ token_correct += ((preds == batch["labels"]) & mask).sum().item()
+ token_total += mask.sum().item()
+ return correct / n_samples, token_correct / max(token_total, 1)
+
+
+def warmup_constant_lr(step, base_lr, warmup):
+ return base_lr * step / max(1, warmup) if step < warmup else base_lr
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--data-path", default="/home/yurenh2/rrm/data/sudoku-extreme-1k-aug-1000")
+ ap.add_argument("--n-steps", type=int, default=3000)
+ ap.add_argument("--batch-size", type=int, default=8)
+ ap.add_argument("--lr", type=float, default=1e-4)
+ ap.add_argument("--puzzle-emb-lr", type=float, default=1e-4)
+ ap.add_argument("--warmup-steps", type=int, default=200)
+ ap.add_argument("--weight-decay", type=float, default=1.0)
+ ap.add_argument("--hidden-size", type=int, default=256)
+ ap.add_argument("--num-heads", type=int, default=4)
+ ap.add_argument("--H-cycles", type=int, default=2)
+ ap.add_argument("--L-cycles", type=int, default=2)
+ ap.add_argument("--H-layers", type=int, default=4)
+ ap.add_argument("--L-layers", type=int, default=4)
+ ap.add_argument("--orth-s-min", type=float, default=0.95, help="min diag scale (weak orthogonality)")
+ ap.add_argument("--cosine-attn-tau", type=float, default=8.0)
+ ap.add_argument("--seed", type=int, default=42)
+ ap.add_argument("--eval-every", type=int, default=300)
+ ap.add_argument("--eval-n", type=int, default=512)
+ ap.add_argument("--eval-batch-size", type=int, default=32)
+ ap.add_argument("--out", required=True)
+ ap.add_argument("--save-ckpt", default="")
+ args = ap.parse_args()
+
+ device = "cuda"
+ torch.manual_seed(args.seed); np.random.seed(args.seed)
+ data_path = Path(args.data_path)
+ head, base, train_meta = build_model(
+ data_path, args.batch_size, device,
+ hidden_size=args.hidden_size, num_heads=args.num_heads,
+ H_cycles=args.H_cycles, L_cycles=args.L_cycles,
+ H_layers=args.H_layers, L_layers=args.L_layers,
+ orth_s_min=args.orth_s_min, cosine_attn_tau=args.cosine_attn_tau,
+ )
+ n_params = sum(p.numel() for p in head.parameters())
+ print(f"Built HRM-Orth | params={n_params:,} | hidden={args.hidden_size} "
+ f"H_layers={args.H_layers} L_layers={args.L_layers} "
+ f"s_min={args.orth_s_min} τ={args.cosine_attn_tau}")
+
+ puzzle_emb_opt = CastedSparseEmbeddingSignSGD_Distributed(
+ base.inner.puzzle_emb.buffers(), lr=0, weight_decay=args.weight_decay, world_size=1)
+ main_opt = AdamATan2(head.parameters(), lr=0, betas=(0.9, 0.95), weight_decay=args.weight_decay)
+
+ acc0, tacc0 = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device)
+ print(f"=== step 0 (random init): exact_acc = {acc0:.4f} token_acc = {tacc0:.4f} ===")
+
+ log = {"args": vars(args), "n_params": n_params, "initial_acc": acc0, "initial_tok_acc": tacc0, "steps": [], "evals": []}
+ log["evals"].append({"step": 0, "acc": acc0, "tok_acc": tacc0})
+ t0 = time.time()
+ train_iter = load_train_batches(data_path, args.batch_size, args.n_steps, seed=args.seed)
+
+ for step, batch in enumerate(train_iter):
+ batch = {k: v.to(device) for k, v in batch.items()}
+ cur_lr = warmup_constant_lr(step, args.lr, args.warmup_steps)
+ cur_pe_lr = warmup_constant_lr(step, args.puzzle_emb_lr, args.warmup_steps)
+ for pg in main_opt.param_groups: pg["lr"] = cur_lr
+ for pg in puzzle_emb_opt.param_groups: pg["lr"] = cur_pe_lr
+
+ head.train()
+ with torch.device(device):
+ carry = base.initial_carry(batch)
+ sup_loss_sum = 0.0; n_loss = 0
+ for _ in range(base.config.halt_max_steps):
+ carry, l, metrics, _, all_finish = head(return_keys=[], carry=carry, batch=batch)
+ sup_loss_sum = sup_loss_sum + l
+ n_loss += 1
+ if all_finish: break
+ sup_loss = sup_loss_sum / max(n_loss, 1) / args.batch_size
+
+ puzzle_emb_opt.zero_grad(set_to_none=True)
+ main_opt.zero_grad(set_to_none=True)
+ sup_loss.backward()
+ torch.nn.utils.clip_grad_norm_([p for p in head.parameters() if p.requires_grad], 1.0)
+ main_opt.step()
+ puzzle_emb_opt.step()
+
+ rec = {"step": step, "lr": cur_lr, "sup_loss": float(sup_loss.item())}
+ log["steps"].append(rec)
+ if step % 25 == 0 or step == args.n_steps - 1:
+ print(f" [{step:>4}/{args.n_steps}] dt={time.time()-t0:.0f}s lr={cur_lr:.1e} "
+ f"sup={rec['sup_loss']:.4f}", flush=True)
+
+ if (step + 1) % args.eval_every == 0 or step == args.n_steps - 1:
+ acc, tacc = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device)
+ print(f" >> EVAL @ {step+1}: exact_acc={acc:.4f} tok_acc={tacc:.4f} (Δ init: {acc-acc0:+.4f})", flush=True)
+ log["evals"].append({"step": step + 1, "acc": acc, "tok_acc": tacc})
+
+ log["final_acc"] = log["evals"][-1]["acc"]
+ log["final_tok_acc"] = log["evals"][-1]["tok_acc"]
+ Path(args.out).parent.mkdir(parents=True, exist_ok=True)
+ Path(args.out).write_text(json.dumps(log, indent=2))
+ print(f"\n=== DONE === init {acc0:.4f} → final {log['final_acc']:.4f} log → {args.out}")
+
+ if args.save_ckpt:
+ Path(args.save_ckpt).parent.mkdir(parents=True, exist_ok=True)
+ torch.save({"state_dict": head.state_dict(), "args": vars(args),
+ "n_steps_trained": args.n_steps, "final_acc": log["final_acc"], "n_params": n_params},
+ args.save_ckpt)
+ print(f"checkpoint → {args.save_ckpt}")
+
+
+if __name__ == "__main__":
+ main()