"""Step 8: Hidden-state basin consistency regularization. This tests the "stabilize the correct basin, not the whole Lyapunov spectrum" hypothesis. For each task update, we optionally: 1. roll out a clean teacher trajectory to final logits; 2. roll out to an intermediate recursive state; 3. perturb z_H/z_L there; 4. continue the rollout and penalize final-logit KL to the clean teacher. The supervised loss remains the primary task objective. The consistency term directly asks nearby hidden states to land in the same answer basin. """ from __future__ import annotations import argparse import json import sys import time from dataclasses import replace from pathlib import Path import torch import torch.nn.functional as F FLOSS_DIR = Path(__file__).resolve().parent sys.path.insert(0, str(FLOSS_DIR)) from step7_interfloss import ( # noqa: E402 evaluate, freeze_puzzle_embedding, load_model, load_train_batches, move_batch, ) def rollout_logits_eval(base, batch, device): base.eval() freeze_puzzle_embedding(base) with torch.device(device): carry = base.initial_carry(batch) outputs = None for _ in range(base.config.halt_max_steps): carry, outputs = base(carry=carry, batch=batch) return outputs["logits"] def perturb_inner_carry(carry, noise_std: float): if noise_std <= 0: return carry inner = carry.inner_carry z_h = inner.z_H + noise_std * torch.randn_like(inner.z_H) z_l = inner.z_L + noise_std * torch.randn_like(inner.z_L) return replace(carry, inner_carry=replace(inner, z_H=z_h, z_L=z_l)) def perturbed_rollout_logits(base, batch, device, perturb_after: int, noise_std: float): base.eval() freeze_puzzle_embedding(base) with torch.device(device): carry = base.initial_carry(batch) outputs = None warmup = min(max(perturb_after, 1), base.config.halt_max_steps - 1) with torch.no_grad(): for _ in range(warmup): carry, outputs = base(carry=carry, batch=batch) carry = perturb_inner_carry(carry, noise_std) for _ in range(warmup, base.config.halt_max_steps): carry, outputs = base(carry=carry, batch=batch) return outputs["logits"] def consistency_loss(args, base, batch, device): with torch.no_grad(): teacher_logits = rollout_logits_eval(base, batch, device).detach().to(torch.float32) student_logits = perturbed_rollout_logits( base, batch, device, perturb_after=args.perturb_after, noise_std=args.noise_std, ).to(torch.float32) temp = args.kl_temperature teacher_p = F.softmax(teacher_logits / temp, dim=-1) student_logp = F.log_softmax(student_logits / temp, dim=-1) kl_per_token = F.kl_div(student_logp, teacher_p, reduction="none").sum(dim=-1) * (temp ** 2) mask = batch["labels"] > 0 if mask.any(): return kl_per_token[mask].mean() return kl_per_token.mean() def supervised_backward(head, base, batch, device): head.train() freeze_puzzle_embedding(base) with torch.device(device): carry = base.initial_carry(batch) sup_loss_sum = 0.0 n_loss = 0 for _ in range(base.config.halt_max_steps): carry, loss, _metrics, _outputs, all_finish = head(return_keys=[], carry=carry, batch=batch) sup_loss_sum = sup_loss_sum + loss n_loss += 1 if all_finish: break sup_loss = sup_loss_sum / max(n_loss, 1) / batch["inputs"].shape[0] sup_loss.backward() return sup_loss.detach() def write_log(path: str, log: dict): Path(path).write_text(json.dumps(log, indent=2)) def main(): parser = argparse.ArgumentParser() parser.add_argument("--model", choices=["hrm", "trm"], required=True) parser.add_argument("--ckpt-root", required=True) parser.add_argument("--ckpt-name", required=True) parser.add_argument("--train-steps", type=int, default=10000) parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--lr", type=float, default=1e-5) parser.add_argument("--consistency-beta", type=float, default=1.0) parser.add_argument("--consistency-every", type=int, default=1) parser.add_argument("--perturb-after", type=int, default=8) parser.add_argument("--noise-std", type=float, default=0.02) parser.add_argument("--kl-temperature", type=float, default=1.0) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--eval-every", type=int, default=1000) parser.add_argument("--eval-n", type=int, default=512) parser.add_argument("--eval-batch-size", type=int, default=32) parser.add_argument("--out", default="step8_basin_consistency_log.json") args = parser.parse_args() device = "cuda" head, base, cfg, adam_cls = load_model(args.model, Path(args.ckpt_root), args.ckpt_name, device) data_path = Path(cfg["data_path"]) optim = adam_cls(head.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=cfg["weight_decay"]) print(f"\n=== Initial eval (loaded {args.ckpt_name}) ===") acc0, tok0 = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) print(f" initial: exact_acc={acc0:.4f} token_acc={tok0:.4f}", flush=True) log = { "args": vars(args), "initial_acc": acc0, "initial_tok_acc": tok0, "steps": [], "evals": [{"step": 0, "acc": acc0, "tok_acc": tok0}], } write_log(args.out, log) train_iter = load_train_batches(data_path, args.batch_size, args.train_steps, seed=args.seed) t0 = time.time() for step, batch_cpu in enumerate(train_iter): batch = move_batch(batch_cpu, device) optim.zero_grad(set_to_none=True) sup_loss = supervised_backward(head, base, batch, device) cons_loss = torch.zeros((), device=device) if args.consistency_beta > 0 and step % args.consistency_every == 0: cons_loss = consistency_loss(args, base, batch, device) (args.consistency_beta * cons_loss).backward() torch.nn.utils.clip_grad_norm_([p for p in head.parameters() if p.requires_grad], 1.0) optim.step() rec = { "step": step + 1, "sup_loss": float(sup_loss.item()), "consistency_loss": float(cons_loss.detach().item()), "total_loss": float(sup_loss.item() + args.consistency_beta * cons_loss.detach().item()), } log["steps"].append(rec) if step % 50 == 0 or step == args.train_steps - 1: print( f" [{step + 1:>5}/{args.train_steps}] dt={time.time() - t0:.1f}s " f"sup={rec['sup_loss']:.4f} cons={rec['consistency_loss']:.6f}", flush=True, ) if (step + 1) % args.eval_every == 0: acc, tok_acc = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) print(f" >> EVAL @ step {step + 1}: exact_acc={acc:.4f} delta={acc - acc0:+.4f}", flush=True) log["evals"].append({"step": step + 1, "acc": acc, "tok_acc": tok_acc}) write_log(args.out, log) acc_f, tok_f = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) print("\n=== Final eval ===") print(f" initial={acc0:.4f} final={acc_f:.4f} delta={acc_f - acc0:+.4f}", flush=True) log["final_acc"] = acc_f log["final_tok_acc"] = tok_f log["evals"].append({"step": args.train_steps, "acc": acc_f, "tok_acc": tok_f}) write_log(args.out, log) print(f"log -> {args.out}") if __name__ == "__main__": main()