summaryrefslogtreecommitdiff
path: root/research/flossing/reverse_floss_finetune.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/reverse_floss_finetune.py')
-rw-r--r--research/flossing/reverse_floss_finetune.py289
1 files changed, 289 insertions, 0 deletions
diff --git a/research/flossing/reverse_floss_finetune.py b/research/flossing/reverse_floss_finetune.py
new file mode 100644
index 0000000..0e767ea
--- /dev/null
+++ b/research/flossing/reverse_floss_finetune.py
@@ -0,0 +1,289 @@
+"""Reverse-flossing fine-tune of a trained HRM checkpoint.
+
+L_total = L_HRM_ACT (existing supervised loss) + alpha * L_RF
+L_RF = mean over batch of max(0, mean_lambda - lambda_star) ** 2
+
+Computes finite-time top-k Lyapunov spectrum during the forward (with grad enabled
+so that L_RF is differentiable wrt model params), then averages and applies a hinge.
+"""
+from __future__ import annotations
+import sys, os, yaml, json, math, time, argparse
+from pathlib import Path
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+HRM_DIR = Path("/home/yurenh2/rrm/hrm")
+sys.path.insert(0, str(HRM_DIR))
+
+from models.hrm.hrm_act_v1 import (
+ HierarchicalReasoningModel_ACTV1,
+ HierarchicalReasoningModel_ACTV1InnerCarry,
+)
+from models.losses import ACTLossHead
+from adam_atan2 import AdamATan2
+
+# ----------------- helpers -----------------
+
+def load_model_for_finetune(ckpt_root: Path, ckpt_name: str, device: str):
+ cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text())
+ arch_cfg = dict(cfg["arch"])
+ train_meta = json.loads((Path(cfg["data_path"]) / "train" / "dataset.json").read_text())
+ arch_cfg.update(
+ batch_size=cfg["global_batch_size"],
+ seq_len=train_meta["seq_len"],
+ vocab_size=train_meta["vocab_size"],
+ num_puzzle_identifiers=train_meta["num_puzzle_identifiers"],
+ causal=False,
+ )
+ base = HierarchicalReasoningModel_ACTV1(arch_cfg)
+ head = ACTLossHead(base, loss_type=arch_cfg["loss"]["loss_type"])
+
+ sd = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True)
+ stripped = {}
+ for k, v in sd.items():
+ nk = k
+ for prefix in ("_orig_mod.",):
+ if nk.startswith(prefix): nk = nk[len(prefix):]
+ stripped[nk] = v
+ missing, unexpected = head.load_state_dict(stripped, strict=False)
+ print(f"[load] missing={len(missing)} unexpected={len(unexpected)}")
+ if missing[:3]: print(f" sample missing: {missing[:3]}")
+ if unexpected[:3]: print(f" sample unexpected: {unexpected[:3]}")
+ head.to(device)
+ return head, base, cfg, train_meta
+
+
+def jvp_apply(f, x, V):
+ """D_f(x) @ V where V is (B, state_dim, k). Uses create_graph=True so RF loss is
+ differentiable wrt model params."""
+ B, state_dim, k = V.shape
+ out = []
+ fx_last = None
+ for i in range(k):
+ v_i = V[..., i].view_as(x)
+ fx, Dv = torch.autograd.functional.jvp(f, x, v=v_i, create_graph=True, strict=False)
+ out.append(Dv.reshape(B, state_dim).to(torch.float32))
+ fx_last = fx
+ return fx_last, torch.stack(out, dim=-1)
+
+
+def compute_lyap_mean(model, inner, batch, k_lyap, device, seed, lyap_act_steps=4):
+ """Run forward with grad and accumulate top-k Lyapunov estimate per sample.
+ Returns mean_lambda per sample (B,) — DIFFERENTIABLE wrt params.
+
+ `lyap_act_steps`: how many ACT steps to unroll for the Lyapunov estimate.
+ Fewer steps → lower memory but noisier estimate. T = lyap_act_steps * (H*L + H).
+ """
+ cfg = inner.config
+ B = batch["inputs"].shape[0]
+ seq_full = inner.config.seq_len + inner.puzzle_emb_len
+ hidden = cfg.hidden_size
+ state_dim = seq_full * hidden
+
+ # Init
+ z_H = inner.H_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype)
+ z_L = inner.L_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype)
+ seq_info = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None)
+ input_embeddings = inner._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
+
+ g = torch.Generator(device=device).manual_seed(seed)
+ Q0 = torch.randn(B, state_dim, k_lyap, device=device, dtype=torch.float32, generator=g)
+ Q, _ = torch.linalg.qr(Q0)
+ log_R_sum = torch.zeros(B, k_lyap, device=device, dtype=torch.float32)
+ n_steps = 0
+
+ n_act = min(lyap_act_steps, cfg.halt_max_steps)
+ for _act in range(n_act):
+ for _h in range(cfg.H_cycles):
+ for _l in range(cfg.L_cycles):
+ f = lambda x: inner.L_level(x, z_H + input_embeddings, **seq_info)
+ z_L_new, DQ = jvp_apply(f, z_L, Q)
+ Q = DQ; z_L = z_L_new
+ Q, R = torch.linalg.qr(Q)
+ log_R_sum = log_R_sum + R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log()
+ n_steps += 1
+ f = lambda x: inner.H_level(x, z_L, **seq_info)
+ z_H_new, DQ = jvp_apply(f, z_H, Q)
+ Q = DQ; z_H = z_H_new
+ Q, R = torch.linalg.qr(Q)
+ log_R_sum = log_R_sum + R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log()
+ n_steps += 1
+
+ lyap_spec = log_R_sum / max(n_steps, 1) # (B, k)
+ return lyap_spec.mean(dim=-1) # (B,) mean over top-k
+
+
+# ----------------- data -----------------
+
+def load_train_batches(data_path: Path, batch_size: int, n_iters: int, seed: int = 0):
+ """Tiny iterator over the augmented train set."""
+ rng = np.random.default_rng(seed)
+ inputs = np.load(data_path / "train" / "all__inputs.npy")
+ labels = np.load(data_path / "train" / "all__labels.npy")
+ pid = np.load(data_path / "train" / "all__puzzle_identifiers.npy")
+ N = len(inputs)
+ for _ in range(n_iters):
+ idx = rng.choice(N, size=batch_size, replace=False)
+ yield {
+ "inputs": torch.from_numpy(inputs[idx].astype(np.int32)),
+ "labels": torch.from_numpy(labels[idx].astype(np.int32)),
+ "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)),
+ }
+
+
+def evaluate(head, base, data_path: Path, n_samples: int, batch_size: int, device: str, seed: int = 42):
+ """Quick exact-accuracy eval on test set."""
+ rng = np.random.default_rng(seed)
+ inputs = np.load(data_path / "test" / "all__inputs.npy")
+ labels = np.load(data_path / "test" / "all__labels.npy")
+ pid = np.load(data_path / "test" / "all__puzzle_identifiers.npy")
+ idx_all = rng.choice(len(inputs), size=n_samples, replace=False)
+ head.eval()
+ correct = 0; token_correct = 0; token_total = 0
+ for s in range(0, n_samples, batch_size):
+ e = min(s + batch_size, n_samples)
+ idx = idx_all[s:e]
+ batch = {
+ "inputs": torch.from_numpy(inputs[idx].astype(np.int32)).to(device),
+ "labels": torch.from_numpy(labels[idx].astype(np.int32)).to(device),
+ "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)).to(device),
+ }
+ with torch.no_grad():
+ with torch.device(device):
+ carry = base.initial_carry(batch)
+ for _ in range(base.config.halt_max_steps):
+ carry, outputs = base(carry=carry, batch=batch)
+ preds = outputs["logits"].argmax(dim=-1)
+ mask = batch["labels"] > 0
+ exact = ((preds == batch["labels"]) | ~mask).all(dim=-1).float()
+ correct += exact.sum().item()
+ token_correct += ((preds == batch["labels"]) & mask).sum().item()
+ token_total += mask.sum().item()
+ return correct / n_samples, token_correct / max(token_total, 1)
+
+
+# ----------------- main -----------------
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--ckpt-root", required=True)
+ ap.add_argument("--ckpt-name", default="step_26040")
+ ap.add_argument("--n-steps", type=int, default=200)
+ ap.add_argument("--batch-size", type=int, default=16, help="kept small because RF JVPs use a lot of memory")
+ ap.add_argument("--lr", type=float, default=2e-5)
+ ap.add_argument("--alpha-rf", type=float, default=1.0)
+ ap.add_argument("--lambda-star", type=float, default=-0.85)
+ ap.add_argument("--k-lyap", type=int, default=2)
+ ap.add_argument("--lyap-act-steps", type=int, default=4, help="ACT steps to unroll for Lyapunov measurement")
+ ap.add_argument("--rf-mode", choices=["fixed","horizon"], default="fixed",
+ help="fixed: use --lambda-star as target. horizon: λ*=(1/T)log(eps/loss)")
+ ap.add_argument("--rf-eps", type=float, default=1e-6, help="task tolerance for horizon mode")
+ ap.add_argument("--seed", type=int, default=0)
+ ap.add_argument("--eval-every", type=int, default=50)
+ ap.add_argument("--eval-n", type=int, default=512)
+ ap.add_argument("--eval-batch-size", type=int, default=64)
+ ap.add_argument("--out", default="rf_finetune_log.json")
+ args = ap.parse_args()
+
+ device = "cuda"
+ head, base, cfg, train_meta = load_model_for_finetune(Path(args.ckpt_root), args.ckpt_name, device)
+
+ # Optimizer matching HRM training
+ optim = AdamATan2(head.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=cfg["weight_decay"])
+
+ # Initial eval
+ print(f"\n=== Initial eval (no fine-tune) ===")
+ acc0, tacc0 = evaluate(head, base, Path(cfg["data_path"]), args.eval_n, args.eval_batch_size, device)
+ print(f" initial exact_acc = {acc0:.4f} token_acc = {tacc0:.4f}")
+
+ log = {"args": vars(args), "initial_acc": acc0, "initial_tok_acc": tacc0, "steps": []}
+ t0 = time.time()
+
+ train_iter = load_train_batches(Path(cfg["data_path"]), args.batch_size, args.n_steps, seed=args.seed)
+
+ head.train()
+ # Keep sparse puzzle embedding in eval mode (its local_weights buffer is sized for
+ # the original training batch_size of 768; here we use a smaller batch). The
+ # puzzle_emb table is still in the model, just not updated during fine-tune.
+ base.inner.puzzle_emb.eval()
+ for p in base.inner.puzzle_emb.parameters():
+ p.requires_grad_(False)
+
+ for step, batch in enumerate(train_iter):
+ batch = {k: v.to(device) for k, v in batch.items()}
+ # Make sure model is in train mode but puzzle_emb stays in eval mode (its
+ # local_weights buffer is sized for the original 768 batch_size; we use a smaller batch).
+ head.train()
+ base.inner.puzzle_emb.eval()
+
+ # ---- Supervised ACT loss accumulated over all halt_max_steps ACT steps ----
+ # We use a fresh carry per fine-tune step; run ACT loop fully and sum losses.
+ with torch.device(device):
+ carry = base.initial_carry(batch)
+ sup_loss_sum = 0.0
+ n_loss = 0
+ for _ in range(base.config.halt_max_steps):
+ carry, l, metrics, _, all_finish = head(return_keys=[], carry=carry, batch=batch)
+ sup_loss_sum = sup_loss_sum + l
+ n_loss += 1
+ if all_finish: break
+ # Average across ACT steps, then normalize by batch_size
+ sup_loss = sup_loss_sum / max(n_loss, 1) / args.batch_size
+
+ # ---- Reverse-flossing penalty ----
+ # New forward pass dedicated to computing finite-time mean λ (differentiable).
+ mean_lyap = compute_lyap_mean(head, base.inner, batch, args.k_lyap, device,
+ seed=args.seed + step, lyap_act_steps=args.lyap_act_steps)
+ # λ̄ is in nats per inner-cycle.
+ if args.rf_mode == "fixed":
+ lam_star = torch.tensor(args.lambda_star, device=device, dtype=torch.float32)
+ else:
+ # finite-horizon: λ* = (1/T) log(eps / r); r = current per-sample sup loss proxy (use total loss)
+ T = base.config.halt_max_steps * (base.config.H_cycles * base.config.L_cycles + base.config.H_cycles)
+ r = sup_loss.detach().clamp_min(1e-9)
+ lam_star = (1.0 / T) * torch.log(torch.tensor(args.rf_eps, device=device) / r)
+ # Cap at small negative so we don't ask for impossible contraction
+ lam_star = lam_star.clamp(min=-2.0, max=0.0)
+
+ excess = (mean_lyap - lam_star).clamp_min(0.0) # 0 if lyap < star (already contractive enough)
+ rf_loss = (excess ** 2).mean()
+
+ total_loss = sup_loss + args.alpha_rf * rf_loss
+
+ optim.zero_grad(set_to_none=True)
+ total_loss.backward()
+ torch.nn.utils.clip_grad_norm_(head.parameters(), 1.0)
+ optim.step()
+
+ lam_mean_val = mean_lyap.detach().mean().item()
+ rec = {"step": step, "sup_loss": float(sup_loss.item()), "rf_loss": float(rf_loss.item()),
+ "total_loss": float(total_loss.item()), "mean_lyap": lam_mean_val,
+ "lam_star": float(lam_star.mean().item() if lam_star.dim() else lam_star.item()),
+ "excess_frac_nonzero": float((excess > 0).float().mean().item())}
+ log["steps"].append(rec)
+
+ if step % 5 == 0 or step == args.n_steps - 1:
+ print(f" [{step:>4}/{args.n_steps}] dt={time.time()-t0:.1f}s "
+ f"sup={rec['sup_loss']:.4f} rf={rec['rf_loss']:.4f} "
+ f"lyap={lam_mean_val:+.4f} λ*={rec['lam_star']:+.4f} "
+ f"nz={rec['excess_frac_nonzero']:.2f}", flush=True)
+
+ if (step + 1) % args.eval_every == 0:
+ acc, tacc = evaluate(head, base, Path(cfg["data_path"]), args.eval_n, args.eval_batch_size, device)
+ print(f" >> EVAL step {step+1}: exact_acc={acc:.4f} (Δ from init: {acc-acc0:+.4f})")
+ log["steps"][-1]["eval_acc"] = acc
+ log["steps"][-1]["eval_tok_acc"] = tacc
+ head.train()
+
+ print(f"\n=== Final eval ===")
+ acc_f, tacc_f = evaluate(head, base, Path(cfg["data_path"]), args.eval_n, args.eval_batch_size, device)
+ print(f" initial: {acc0:.4f} → final: {acc_f:.4f} (Δ {acc_f-acc0:+.4f})")
+ log["final_acc"] = acc_f; log["final_tok_acc"] = tacc_f
+
+ Path(args.out).write_text(json.dumps(log, indent=2))
+ print(f"log saved → {args.out}")
+
+
+if __name__ == "__main__":
+ main()