diff options
Diffstat (limited to 'research/flossing/step4_from_scratch.py')
| -rw-r--r-- | research/flossing/step4_from_scratch.py | 334 |
1 files changed, 334 insertions, 0 deletions
diff --git a/research/flossing/step4_from_scratch.py b/research/flossing/step4_from_scratch.py new file mode 100644 index 0000000..e1dd738 --- /dev/null +++ b/research/flossing/step4_from_scratch.py @@ -0,0 +1,334 @@ +"""Step 4: From-scratch HRM training with CF regularizer (no checkpoint load). + +Tests whether forcing λ_joint_1 → λ* from step 0 affects learning trajectory. + +Hypothesis test: + - Baseline (α=0): λ naturally drifts toward HRM's attractor (~-0.15) + - CF λ*=0: forces λ to stay near 0 (edge of chaos) + - CF λ*=-0.15: enforces natural attractor from start (should be neutral) + +For each condition we track λ trajectory + acc + halt distribution at fixed steps. +""" +from __future__ import annotations +import sys, os, yaml, json, math, time, argparse +from pathlib import Path +import numpy as np +import torch +import torch.nn.functional as F + +HRM_DIR = Path("/home/yurenh2/rrm/hrm") +sys.path.insert(0, str(HRM_DIR)) + +from models.hrm.hrm_act_v1 import HierarchicalReasoningModel_ACTV1 +from models.losses import ACTLossHead +from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed +from adam_atan2 import AdamATan2 + + +def build_model_from_scratch(data_path: Path, batch_size: int, device: str, + hidden_size: int = 512, num_heads: int = 8): + """Build HRM with the official Sudoku-1k arch config but at our batch size + arbitrary hidden.""" + train_meta = json.loads((data_path / "train" / "dataset.json").read_text()) + arch_cfg = dict( + H_cycles=2, H_layers=4, + L_cycles=2, L_layers=4, + expansion=4, + halt_exploration_prob=0.1, + halt_max_steps=16, + hidden_size=hidden_size, + num_heads=num_heads, + pos_encodings="rope", + puzzle_emb_ndim=hidden_size, + loss=dict(loss_type="stablemax_cross_entropy", name="losses@ACTLossHead"), + batch_size=batch_size, + vocab_size=train_meta["vocab_size"], + seq_len=train_meta["seq_len"], + num_puzzle_identifiers=train_meta["num_puzzle_identifiers"], + causal=False, + ) + with torch.device(device): + base = HierarchicalReasoningModel_ACTV1(arch_cfg) + head = ACTLossHead(base, loss_type=arch_cfg["loss"]["loss_type"]) + return head, base, train_meta + + +def jvp_train(f, x, v): + return torch.autograd.functional.jvp(f, x, v=v, create_graph=True, strict=False) + + +def compute_joint_lyap_spec(base, batch, k_lyap, lyap_act_steps, device, seed, with_grad=True): + """Returns FULL top-k Lyapunov spectrum (B, k), differentiable wrt theta.""" + inner = base.inner + cfg = inner.config + B = batch["inputs"].shape[0] + seq_full = cfg.seq_len + inner.puzzle_emb_len + hidden = cfg.hidden_size + D = seq_full * hidden + + z_H = inner.H_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) + z_L = inner.L_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) + seq_info = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None) + input_embeddings = inner._input_embeddings(batch["inputs"], batch["puzzle_identifiers"]) + + g = torch.Generator(device=device).manual_seed(seed) + Q0 = torch.randn(B, 2*D, k_lyap, device=device, dtype=torch.float32, generator=g) + Q, _ = torch.linalg.qr(Q0) + log_R_sum = torch.zeros(B, k_lyap, device=device, dtype=torch.float32) + n_steps = 0 + + jvp_fn = (lambda f, x, v: torch.autograd.functional.jvp(f, x, v=v, create_graph=with_grad, strict=False)) + + n_act = min(lyap_act_steps, cfg.halt_max_steps) + for _act in range(n_act): + for _h in range(cfg.H_cycles): + for _l in range(cfg.L_cycles): + v_H_j = Q[:, :D, :]; v_L_j = Q[:, D:, :] + v_comb = v_H_j + v_L_j + new_v_L_cols = [] + f_L = lambda z: inner.L_level(z, z_H + input_embeddings, **seq_info) + for i in range(k_lyap): + v_i = v_comb[:, :, i].reshape(B, seq_full, hidden).to(inner.forward_dtype) + z_L_new, Dv = jvp_fn(f_L, z_L, v_i) + new_v_L_cols.append(Dv.reshape(B, D).to(torch.float32)) + new_v_L = torch.stack(new_v_L_cols, dim=-1) + Q = torch.cat([v_H_j, new_v_L], dim=1) + z_L = z_L_new + Q, R = torch.linalg.qr(Q) + log_R_sum = log_R_sum + R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log() + n_steps += 1 + v_H_j = Q[:, :D, :]; v_L_j = Q[:, D:, :] + v_comb = v_H_j + v_L_j + new_v_H_cols = [] + f_H = lambda z: inner.H_level(z, z_L, **seq_info) + for i in range(k_lyap): + v_i = v_comb[:, :, i].reshape(B, seq_full, hidden).to(inner.forward_dtype) + z_H_new, Dv = jvp_fn(f_H, z_H, v_i) + new_v_H_cols.append(Dv.reshape(B, D).to(torch.float32)) + new_v_H = torch.stack(new_v_H_cols, dim=-1) + Q = torch.cat([new_v_H, v_L_j], dim=1) + z_H = z_H_new + Q, R = torch.linalg.qr(Q) + log_R_sum = log_R_sum + R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log() + n_steps += 1 + + lyap_spec = log_R_sum / max(n_steps, 1) + return lyap_spec # (B, k_lyap) + + +def load_train_batches(data_path: Path, batch_size: int, n_iters: int, seed: int = 0): + rng = np.random.default_rng(seed) + inputs = np.load(data_path / "train" / "all__inputs.npy") + labels = np.load(data_path / "train" / "all__labels.npy") + pid = np.load(data_path / "train" / "all__puzzle_identifiers.npy") + N = len(inputs) + for _ in range(n_iters): + idx = rng.choice(N, size=batch_size, replace=False) + yield { + "inputs": torch.from_numpy(inputs[idx].astype(np.int32)), + "labels": torch.from_numpy(labels[idx].astype(np.int32)), + "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)), + } + + +def evaluate(head, base, data_path, n_samples, batch_size, device, seed=42): + rng = np.random.default_rng(seed) + inputs = np.load(data_path / "test" / "all__inputs.npy") + labels = np.load(data_path / "test" / "all__labels.npy") + pid = np.load(data_path / "test" / "all__puzzle_identifiers.npy") + idx_all = rng.choice(len(inputs), size=n_samples, replace=False) + head.eval() + correct = 0; token_correct = 0; token_total = 0 + for s in range(0, n_samples, batch_size): + e = min(s + batch_size, n_samples) + idx = idx_all[s:e] + batch = { + "inputs": torch.from_numpy(inputs[idx].astype(np.int32)).to(device), + "labels": torch.from_numpy(labels[idx].astype(np.int32)).to(device), + "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)).to(device), + } + with torch.no_grad(): + with torch.device(device): + carry = base.initial_carry(batch) + for _ in range(base.config.halt_max_steps): + carry, outputs = base(carry=carry, batch=batch) + preds = outputs["logits"].argmax(dim=-1) + mask = batch["labels"] > 0 + exact = ((preds == batch["labels"]) | ~mask).all(dim=-1).float() + correct += exact.sum().item() + token_correct += ((preds == batch["labels"]) & mask).sum().item() + token_total += mask.sum().item() + return correct / n_samples, token_correct / max(token_total, 1) + + +def warmup_constant_lr(step, base_lr, warmup): + if step < warmup: + return base_lr * step / max(1, warmup) + return base_lr + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--data-path", default="/home/yurenh2/rrm/data/sudoku-extreme-1k-aug-1000") + ap.add_argument("--n-steps", type=int, default=2500) + ap.add_argument("--batch-size", type=int, default=8) + ap.add_argument("--lr", type=float, default=1e-4) + ap.add_argument("--puzzle-emb-lr", type=float, default=1e-4) + ap.add_argument("--warmup-steps", type=int, default=200) + ap.add_argument("--weight-decay", type=float, default=1.0) + ap.add_argument("--hidden-size", type=int, default=512, help="HRM hidden dim (default official 512)") + ap.add_argument("--num-heads", type=int, default=8, help="must divide hidden_size; default 8") + ap.add_argument("--alpha-rf", type=float, default=0.0, help="0 = baseline; >0 = CF/Engelken") + ap.add_argument("--rf-mode", choices=["fixed", "volume_cf", "engelken_l2"], default="fixed", + help="fixed: hinge max(0, λ_1-λ*)² on top-1; " + "volume_cf: hinge max(0, mean_i λ_i-λ*)² over top-k; " + "engelken_l2: (1/k) Σ λ_i² across full top-k spectrum") + ap.add_argument("--lambda-star", type=float, default=0.0, help="used in fixed and volume_cf modes") + ap.add_argument("--k-lyap", type=int, default=2) + ap.add_argument("--lyap-act-steps", type=int, default=4) + ap.add_argument("--seed", type=int, default=42) + ap.add_argument("--eval-every", type=int, default=250) + ap.add_argument("--eval-n", type=int, default=512) + ap.add_argument("--eval-batch-size", type=int, default=32) + ap.add_argument("--out", required=True) + ap.add_argument("--save-ckpt", default="", help="path to save final model state_dict (empty = skip)") + args = ap.parse_args() + + device = "cuda" + torch.manual_seed(args.seed); np.random.seed(args.seed) + data_path = Path(args.data_path) + head, base, train_meta = build_model_from_scratch(data_path, args.batch_size, device, + hidden_size=args.hidden_size, + num_heads=args.num_heads) + print(f"Built HRM from scratch | params={sum(p.numel() for p in head.parameters()):,} | " + f"vocab={train_meta['vocab_size']} seq={train_meta['seq_len']} " + f"num_pids={train_meta['num_puzzle_identifiers']}") + + # Two optimizers: SignSGD for puzzle_emb (sparse), AdamATan2 for rest + puzzle_emb_opt = CastedSparseEmbeddingSignSGD_Distributed( + base.inner.puzzle_emb.buffers(), + lr=0, + weight_decay=args.weight_decay, + world_size=1, + ) + main_opt = AdamATan2(head.parameters(), lr=0, betas=(0.9, 0.95), weight_decay=args.weight_decay) + + # Baseline eval (random init) + acc0, tacc0 = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) + print(f"=== step 0 (random init): exact_acc = {acc0:.4f} token_acc = {tacc0:.4f} ===") + + log = {"args": vars(args), "initial_acc": acc0, "initial_tok_acc": tacc0, "steps": [], "evals": []} + log["evals"].append({"step": 0, "acc": acc0, "tok_acc": tacc0}) + t0 = time.time() + train_iter = load_train_batches(data_path, args.batch_size, args.n_steps, seed=args.seed) + + for step, batch in enumerate(train_iter): + batch = {k: v.to(device) for k, v in batch.items()} + + # Update LR + cur_lr = warmup_constant_lr(step, args.lr, args.warmup_steps) + cur_pe_lr = warmup_constant_lr(step, args.puzzle_emb_lr, args.warmup_steps) + for pg in main_opt.param_groups: pg["lr"] = cur_lr + for pg in puzzle_emb_opt.param_groups: pg["lr"] = cur_pe_lr + + head.train() + + # ACT loss + with torch.device(device): + carry = base.initial_carry(batch) + sup_loss_sum = 0.0; n_loss = 0 + for _ in range(base.config.halt_max_steps): + carry, l, metrics, _, all_finish = head(return_keys=[], carry=carry, batch=batch) + sup_loss_sum = sup_loss_sum + l + n_loss += 1 + if all_finish: break + sup_loss = sup_loss_sum / max(n_loss, 1) / args.batch_size + + # CF / Engelken loss (skip if alpha=0; still measure λ for logging in baseline) + if args.alpha_rf > 0: + lyap_spec = compute_joint_lyap_spec(base, batch, args.k_lyap, args.lyap_act_steps, + device, args.seed + step, with_grad=True) + lyap1 = lyap_spec[:, 0] + if args.rf_mode == "engelken_l2": + # Engelken: L = (1/k) Σ_i λ_i² over batch → push all spectrum toward 0 + rf_loss = (lyap_spec ** 2).mean() + excess = lyap1 # for logging (no hinge here) + elif args.rf_mode == "volume_cf": + # One-sided cap on local phase-space volume expansion. + # Allows λ_1 > 0 when compensated by contraction in other measured modes. + lyap_volume = lyap_spec.mean(dim=1) + excess = (lyap_volume - args.lambda_star).clamp_min(0.0) + rf_loss = (excess ** 2).mean() + else: # fixed (hinge on top-1) + excess = (lyap1 - args.lambda_star).clamp_min(0.0) + rf_loss = (excess ** 2).mean() + else: + with torch.no_grad(): + lyap_spec = compute_joint_lyap_spec(base, batch, args.k_lyap, args.lyap_act_steps, + device, args.seed + step, with_grad=False) + lyap1 = lyap_spec[:, 0] + rf_loss = torch.zeros((), device=device) + excess = (lyap1 - args.lambda_star).clamp_min(0.0) + + total_loss = sup_loss + args.alpha_rf * rf_loss + + puzzle_emb_opt.zero_grad(set_to_none=True) + main_opt.zero_grad(set_to_none=True) + total_loss.backward() + torch.nn.utils.clip_grad_norm_([p for p in head.parameters() if p.requires_grad], 1.0) + main_opt.step() + puzzle_emb_opt.step() + + with torch.no_grad(): + lyap_mean_per_i = lyap_spec.detach().mean(dim=0).cpu().tolist() + rec = { + "step": step, "lr": cur_lr, + "sup_loss": float(sup_loss.item()), + "rf_loss": float(rf_loss.item()), + "total_loss": float(total_loss.item()), + "lyap1_mean": float(lyap1.detach().mean().item()), + "lyap1_max": float(lyap1.detach().max().item()), + "lyap1_min": float(lyap1.detach().min().item()), + "lyap_volume_mean": float(lyap_spec.detach().mean(dim=1).mean().item()), + "lyap_volume_max": float(lyap_spec.detach().mean(dim=1).max().item()), + "lyap_spec_mean": lyap_mean_per_i, + "frac_above_star": float((excess > 0).float().mean().item()), + } + log["steps"].append(rec) + if step % 25 == 0 or step == args.n_steps - 1: + print(f" [{step:>4}/{args.n_steps}] dt={time.time()-t0:.0f}s lr={cur_lr:.1e} " + f"sup={rec['sup_loss']:.4f} rf={rec['rf_loss']:.4f} " + f"λ_mean={rec['lyap1_mean']:+.4f} [{rec['lyap1_min']:+.3f},{rec['lyap1_max']:+.3f}] " + f"frac>λ*={rec['frac_above_star']:.2f}", flush=True) + + if (step + 1) % args.eval_every == 0 or step == args.n_steps - 1: + acc, tacc = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) + print(f" >> EVAL @ {step+1}: exact_acc={acc:.4f} tok_acc={tacc:.4f} " + f"(Δ from init: {acc-acc0:+.4f})", flush=True) + log["evals"].append({"step": step + 1, "acc": acc, "tok_acc": tacc}) + + log["final_acc"] = log["evals"][-1]["acc"] + log["final_tok_acc"] = log["evals"][-1]["tok_acc"] + Path(args.out).write_text(json.dumps(log, indent=2)) + print(f"\n=== DONE === init {acc0:.4f} → final {log['final_acc']:.4f} log → {args.out}") + + if args.save_ckpt: + save_path = Path(args.save_ckpt) + save_path.parent.mkdir(parents=True, exist_ok=True) + torch.save({ + "state_dict": head.state_dict(), + "args": vars(args), + "n_steps_trained": args.n_steps, + "final_acc": log["final_acc"], + "final_tok_acc": log["final_tok_acc"], + "arch_cfg_signature": { + "vocab_size": train_meta["vocab_size"], + "seq_len": train_meta["seq_len"], + "num_puzzle_identifiers": train_meta["num_puzzle_identifiers"], + "batch_size": args.batch_size, + }, + }, save_path) + print(f"checkpoint saved → {save_path}") + + +if __name__ == "__main__": + main() |
