diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
| commit | 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch) | |
| tree | c29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/step8_basin_consistency.py | |
Curated export for clone-and-run Maze training (2x A6000) + diagnostics.
trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible).
Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/step8_basin_consistency.py')
| -rw-r--r-- | research/flossing/step8_basin_consistency.py | 199 |
1 files changed, 199 insertions, 0 deletions
diff --git a/research/flossing/step8_basin_consistency.py b/research/flossing/step8_basin_consistency.py new file mode 100644 index 0000000..8bde719 --- /dev/null +++ b/research/flossing/step8_basin_consistency.py @@ -0,0 +1,199 @@ +"""Step 8: Hidden-state basin consistency regularization. + +This tests the "stabilize the correct basin, not the whole Lyapunov spectrum" +hypothesis. For each task update, we optionally: + 1. roll out a clean teacher trajectory to final logits; + 2. roll out to an intermediate recursive state; + 3. perturb z_H/z_L there; + 4. continue the rollout and penalize final-logit KL to the clean teacher. + +The supervised loss remains the primary task objective. The consistency term +directly asks nearby hidden states to land in the same answer basin. +""" +from __future__ import annotations + +import argparse +import json +import sys +import time +from dataclasses import replace +from pathlib import Path + +import torch +import torch.nn.functional as F + +FLOSS_DIR = Path(__file__).resolve().parent +sys.path.insert(0, str(FLOSS_DIR)) + +from step7_interfloss import ( # noqa: E402 + evaluate, + freeze_puzzle_embedding, + load_model, + load_train_batches, + move_batch, +) + + +def rollout_logits_eval(base, batch, device): + base.eval() + freeze_puzzle_embedding(base) + with torch.device(device): + carry = base.initial_carry(batch) + outputs = None + for _ in range(base.config.halt_max_steps): + carry, outputs = base(carry=carry, batch=batch) + return outputs["logits"] + + +def perturb_inner_carry(carry, noise_std: float): + if noise_std <= 0: + return carry + inner = carry.inner_carry + z_h = inner.z_H + noise_std * torch.randn_like(inner.z_H) + z_l = inner.z_L + noise_std * torch.randn_like(inner.z_L) + return replace(carry, inner_carry=replace(inner, z_H=z_h, z_L=z_l)) + + +def perturbed_rollout_logits(base, batch, device, perturb_after: int, noise_std: float): + base.eval() + freeze_puzzle_embedding(base) + with torch.device(device): + carry = base.initial_carry(batch) + outputs = None + warmup = min(max(perturb_after, 1), base.config.halt_max_steps - 1) + with torch.no_grad(): + for _ in range(warmup): + carry, outputs = base(carry=carry, batch=batch) + carry = perturb_inner_carry(carry, noise_std) + for _ in range(warmup, base.config.halt_max_steps): + carry, outputs = base(carry=carry, batch=batch) + return outputs["logits"] + + +def consistency_loss(args, base, batch, device): + with torch.no_grad(): + teacher_logits = rollout_logits_eval(base, batch, device).detach().to(torch.float32) + student_logits = perturbed_rollout_logits( + base, + batch, + device, + perturb_after=args.perturb_after, + noise_std=args.noise_std, + ).to(torch.float32) + temp = args.kl_temperature + teacher_p = F.softmax(teacher_logits / temp, dim=-1) + student_logp = F.log_softmax(student_logits / temp, dim=-1) + kl_per_token = F.kl_div(student_logp, teacher_p, reduction="none").sum(dim=-1) * (temp ** 2) + mask = batch["labels"] > 0 + if mask.any(): + return kl_per_token[mask].mean() + return kl_per_token.mean() + + +def supervised_backward(head, base, batch, device): + head.train() + freeze_puzzle_embedding(base) + with torch.device(device): + carry = base.initial_carry(batch) + sup_loss_sum = 0.0 + n_loss = 0 + for _ in range(base.config.halt_max_steps): + carry, loss, _metrics, _outputs, all_finish = head(return_keys=[], carry=carry, batch=batch) + sup_loss_sum = sup_loss_sum + loss + n_loss += 1 + if all_finish: + break + sup_loss = sup_loss_sum / max(n_loss, 1) / batch["inputs"].shape[0] + sup_loss.backward() + return sup_loss.detach() + + +def write_log(path: str, log: dict): + Path(path).write_text(json.dumps(log, indent=2)) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", choices=["hrm", "trm"], required=True) + parser.add_argument("--ckpt-root", required=True) + parser.add_argument("--ckpt-name", required=True) + parser.add_argument("--train-steps", type=int, default=10000) + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument("--lr", type=float, default=1e-5) + parser.add_argument("--consistency-beta", type=float, default=1.0) + parser.add_argument("--consistency-every", type=int, default=1) + parser.add_argument("--perturb-after", type=int, default=8) + parser.add_argument("--noise-std", type=float, default=0.02) + parser.add_argument("--kl-temperature", type=float, default=1.0) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--eval-every", type=int, default=1000) + parser.add_argument("--eval-n", type=int, default=512) + parser.add_argument("--eval-batch-size", type=int, default=32) + parser.add_argument("--out", default="step8_basin_consistency_log.json") + args = parser.parse_args() + + device = "cuda" + head, base, cfg, adam_cls = load_model(args.model, Path(args.ckpt_root), args.ckpt_name, device) + data_path = Path(cfg["data_path"]) + optim = adam_cls(head.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=cfg["weight_decay"]) + + print(f"\n=== Initial eval (loaded {args.ckpt_name}) ===") + acc0, tok0 = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) + print(f" initial: exact_acc={acc0:.4f} token_acc={tok0:.4f}", flush=True) + + log = { + "args": vars(args), + "initial_acc": acc0, + "initial_tok_acc": tok0, + "steps": [], + "evals": [{"step": 0, "acc": acc0, "tok_acc": tok0}], + } + write_log(args.out, log) + + train_iter = load_train_batches(data_path, args.batch_size, args.train_steps, seed=args.seed) + t0 = time.time() + for step, batch_cpu in enumerate(train_iter): + batch = move_batch(batch_cpu, device) + optim.zero_grad(set_to_none=True) + sup_loss = supervised_backward(head, base, batch, device) + + cons_loss = torch.zeros((), device=device) + if args.consistency_beta > 0 and step % args.consistency_every == 0: + cons_loss = consistency_loss(args, base, batch, device) + (args.consistency_beta * cons_loss).backward() + + torch.nn.utils.clip_grad_norm_([p for p in head.parameters() if p.requires_grad], 1.0) + optim.step() + + rec = { + "step": step + 1, + "sup_loss": float(sup_loss.item()), + "consistency_loss": float(cons_loss.detach().item()), + "total_loss": float(sup_loss.item() + args.consistency_beta * cons_loss.detach().item()), + } + log["steps"].append(rec) + if step % 50 == 0 or step == args.train_steps - 1: + print( + f" [{step + 1:>5}/{args.train_steps}] dt={time.time() - t0:.1f}s " + f"sup={rec['sup_loss']:.4f} cons={rec['consistency_loss']:.6f}", + flush=True, + ) + + if (step + 1) % args.eval_every == 0: + acc, tok_acc = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) + print(f" >> EVAL @ step {step + 1}: exact_acc={acc:.4f} delta={acc - acc0:+.4f}", flush=True) + log["evals"].append({"step": step + 1, "acc": acc, "tok_acc": tok_acc}) + write_log(args.out, log) + + acc_f, tok_f = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) + print("\n=== Final eval ===") + print(f" initial={acc0:.4f} final={acc_f:.4f} delta={acc_f - acc0:+.4f}", flush=True) + log["final_acc"] = acc_f + log["final_tok_acc"] = tok_f + log["evals"].append({"step": args.train_steps, "acc": acc_f, "tok_acc": tok_f}) + write_log(args.out, log) + print(f"log -> {args.out}") + + +if __name__ == "__main__": + main() |
