summaryrefslogtreecommitdiff
path: root/research/flossing/step8_basin_consistency.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/step8_basin_consistency.py')
-rw-r--r--research/flossing/step8_basin_consistency.py199
1 files changed, 199 insertions, 0 deletions
diff --git a/research/flossing/step8_basin_consistency.py b/research/flossing/step8_basin_consistency.py
new file mode 100644
index 0000000..8bde719
--- /dev/null
+++ b/research/flossing/step8_basin_consistency.py
@@ -0,0 +1,199 @@
+"""Step 8: Hidden-state basin consistency regularization.
+
+This tests the "stabilize the correct basin, not the whole Lyapunov spectrum"
+hypothesis. For each task update, we optionally:
+ 1. roll out a clean teacher trajectory to final logits;
+ 2. roll out to an intermediate recursive state;
+ 3. perturb z_H/z_L there;
+ 4. continue the rollout and penalize final-logit KL to the clean teacher.
+
+The supervised loss remains the primary task objective. The consistency term
+directly asks nearby hidden states to land in the same answer basin.
+"""
+from __future__ import annotations
+
+import argparse
+import json
+import sys
+import time
+from dataclasses import replace
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F
+
+FLOSS_DIR = Path(__file__).resolve().parent
+sys.path.insert(0, str(FLOSS_DIR))
+
+from step7_interfloss import ( # noqa: E402
+ evaluate,
+ freeze_puzzle_embedding,
+ load_model,
+ load_train_batches,
+ move_batch,
+)
+
+
+def rollout_logits_eval(base, batch, device):
+ base.eval()
+ freeze_puzzle_embedding(base)
+ with torch.device(device):
+ carry = base.initial_carry(batch)
+ outputs = None
+ for _ in range(base.config.halt_max_steps):
+ carry, outputs = base(carry=carry, batch=batch)
+ return outputs["logits"]
+
+
+def perturb_inner_carry(carry, noise_std: float):
+ if noise_std <= 0:
+ return carry
+ inner = carry.inner_carry
+ z_h = inner.z_H + noise_std * torch.randn_like(inner.z_H)
+ z_l = inner.z_L + noise_std * torch.randn_like(inner.z_L)
+ return replace(carry, inner_carry=replace(inner, z_H=z_h, z_L=z_l))
+
+
+def perturbed_rollout_logits(base, batch, device, perturb_after: int, noise_std: float):
+ base.eval()
+ freeze_puzzle_embedding(base)
+ with torch.device(device):
+ carry = base.initial_carry(batch)
+ outputs = None
+ warmup = min(max(perturb_after, 1), base.config.halt_max_steps - 1)
+ with torch.no_grad():
+ for _ in range(warmup):
+ carry, outputs = base(carry=carry, batch=batch)
+ carry = perturb_inner_carry(carry, noise_std)
+ for _ in range(warmup, base.config.halt_max_steps):
+ carry, outputs = base(carry=carry, batch=batch)
+ return outputs["logits"]
+
+
+def consistency_loss(args, base, batch, device):
+ with torch.no_grad():
+ teacher_logits = rollout_logits_eval(base, batch, device).detach().to(torch.float32)
+ student_logits = perturbed_rollout_logits(
+ base,
+ batch,
+ device,
+ perturb_after=args.perturb_after,
+ noise_std=args.noise_std,
+ ).to(torch.float32)
+ temp = args.kl_temperature
+ teacher_p = F.softmax(teacher_logits / temp, dim=-1)
+ student_logp = F.log_softmax(student_logits / temp, dim=-1)
+ kl_per_token = F.kl_div(student_logp, teacher_p, reduction="none").sum(dim=-1) * (temp ** 2)
+ mask = batch["labels"] > 0
+ if mask.any():
+ return kl_per_token[mask].mean()
+ return kl_per_token.mean()
+
+
+def supervised_backward(head, base, batch, device):
+ head.train()
+ freeze_puzzle_embedding(base)
+ with torch.device(device):
+ carry = base.initial_carry(batch)
+ sup_loss_sum = 0.0
+ n_loss = 0
+ for _ in range(base.config.halt_max_steps):
+ carry, loss, _metrics, _outputs, all_finish = head(return_keys=[], carry=carry, batch=batch)
+ sup_loss_sum = sup_loss_sum + loss
+ n_loss += 1
+ if all_finish:
+ break
+ sup_loss = sup_loss_sum / max(n_loss, 1) / batch["inputs"].shape[0]
+ sup_loss.backward()
+ return sup_loss.detach()
+
+
+def write_log(path: str, log: dict):
+ Path(path).write_text(json.dumps(log, indent=2))
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", choices=["hrm", "trm"], required=True)
+ parser.add_argument("--ckpt-root", required=True)
+ parser.add_argument("--ckpt-name", required=True)
+ parser.add_argument("--train-steps", type=int, default=10000)
+ parser.add_argument("--batch-size", type=int, default=8)
+ parser.add_argument("--lr", type=float, default=1e-5)
+ parser.add_argument("--consistency-beta", type=float, default=1.0)
+ parser.add_argument("--consistency-every", type=int, default=1)
+ parser.add_argument("--perturb-after", type=int, default=8)
+ parser.add_argument("--noise-std", type=float, default=0.02)
+ parser.add_argument("--kl-temperature", type=float, default=1.0)
+ parser.add_argument("--seed", type=int, default=42)
+ parser.add_argument("--eval-every", type=int, default=1000)
+ parser.add_argument("--eval-n", type=int, default=512)
+ parser.add_argument("--eval-batch-size", type=int, default=32)
+ parser.add_argument("--out", default="step8_basin_consistency_log.json")
+ args = parser.parse_args()
+
+ device = "cuda"
+ head, base, cfg, adam_cls = load_model(args.model, Path(args.ckpt_root), args.ckpt_name, device)
+ data_path = Path(cfg["data_path"])
+ optim = adam_cls(head.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=cfg["weight_decay"])
+
+ print(f"\n=== Initial eval (loaded {args.ckpt_name}) ===")
+ acc0, tok0 = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device)
+ print(f" initial: exact_acc={acc0:.4f} token_acc={tok0:.4f}", flush=True)
+
+ log = {
+ "args": vars(args),
+ "initial_acc": acc0,
+ "initial_tok_acc": tok0,
+ "steps": [],
+ "evals": [{"step": 0, "acc": acc0, "tok_acc": tok0}],
+ }
+ write_log(args.out, log)
+
+ train_iter = load_train_batches(data_path, args.batch_size, args.train_steps, seed=args.seed)
+ t0 = time.time()
+ for step, batch_cpu in enumerate(train_iter):
+ batch = move_batch(batch_cpu, device)
+ optim.zero_grad(set_to_none=True)
+ sup_loss = supervised_backward(head, base, batch, device)
+
+ cons_loss = torch.zeros((), device=device)
+ if args.consistency_beta > 0 and step % args.consistency_every == 0:
+ cons_loss = consistency_loss(args, base, batch, device)
+ (args.consistency_beta * cons_loss).backward()
+
+ torch.nn.utils.clip_grad_norm_([p for p in head.parameters() if p.requires_grad], 1.0)
+ optim.step()
+
+ rec = {
+ "step": step + 1,
+ "sup_loss": float(sup_loss.item()),
+ "consistency_loss": float(cons_loss.detach().item()),
+ "total_loss": float(sup_loss.item() + args.consistency_beta * cons_loss.detach().item()),
+ }
+ log["steps"].append(rec)
+ if step % 50 == 0 or step == args.train_steps - 1:
+ print(
+ f" [{step + 1:>5}/{args.train_steps}] dt={time.time() - t0:.1f}s "
+ f"sup={rec['sup_loss']:.4f} cons={rec['consistency_loss']:.6f}",
+ flush=True,
+ )
+
+ if (step + 1) % args.eval_every == 0:
+ acc, tok_acc = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device)
+ print(f" >> EVAL @ step {step + 1}: exact_acc={acc:.4f} delta={acc - acc0:+.4f}", flush=True)
+ log["evals"].append({"step": step + 1, "acc": acc, "tok_acc": tok_acc})
+ write_log(args.out, log)
+
+ acc_f, tok_f = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device)
+ print("\n=== Final eval ===")
+ print(f" initial={acc0:.4f} final={acc_f:.4f} delta={acc_f - acc0:+.4f}", flush=True)
+ log["final_acc"] = acc_f
+ log["final_tok_acc"] = tok_f
+ log["evals"].append({"step": args.train_steps, "acc": acc_f, "tok_acc": tok_f})
+ write_log(args.out, log)
+ print(f"log -> {args.out}")
+
+
+if __name__ == "__main__":
+ main()