summaryrefslogtreecommitdiff
path: root/research/flossing/step9_trajectory_perturb_train.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
commit66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch)
treec29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/step9_trajectory_perturb_train.py
rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipelineHEADmain
Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/step9_trajectory_perturb_train.py')
-rw-r--r--research/flossing/step9_trajectory_perturb_train.py595
1 files changed, 595 insertions, 0 deletions
diff --git a/research/flossing/step9_trajectory_perturb_train.py b/research/flossing/step9_trajectory_perturb_train.py
new file mode 100644
index 0000000..61cbf04
--- /dev/null
+++ b/research/flossing/step9_trajectory_perturb_train.py
@@ -0,0 +1,595 @@
+"""Step 9: Supervised training with Lyapunov-style initial trajectory perturbations.
+
+This is hidden-trajectory augmentation, not a Lyapunov/flossing objective:
+
+ single_perturbed_ce:
+ sample one tiny perturbation of the initial recursive state and train with
+ the original supervised ACT loss on the same (x, y).
+
+ multi_perturbed_ce:
+ run one clean trajectory plus K-1 independently perturbed trajectories on
+ the same (x, y), average the original supervised ACT losses, and update.
+
+The goal is to enlarge the correct answer basin around the model's nominal
+initial latent state without directly optimizing Lyapunov exponents or KL
+consistency to the model's current answer.
+"""
+from __future__ import annotations
+
+import argparse
+import json
+import math
+import sys
+import time
+from dataclasses import replace
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F
+
+FLOSS_DIR = Path(__file__).resolve().parent
+sys.path.insert(0, str(FLOSS_DIR))
+
+from step7_interfloss import ( # noqa: E402
+ evaluate,
+ freeze_puzzle_embedding,
+ load_model,
+ load_train_batches,
+ move_batch,
+ write_log,
+)
+
+IGNORE_LABEL_ID = -100
+
+
+def _randn_like(tensor: torch.Tensor, generator: torch.Generator) -> torch.Tensor:
+ noise = torch.randn(
+ tensor.shape,
+ device=tensor.device,
+ dtype=torch.float32,
+ generator=generator,
+ )
+ return noise.to(tensor.dtype)
+
+
+def _unit_noise_like(tensor: torch.Tensor, generator: torch.Generator, sampling: str) -> torch.Tensor:
+ if sampling == "uniform":
+ # Match Normal(0, 1) variance: U[-sqrt(3), sqrt(3)] has std=1.
+ noise = torch.rand(tensor.shape, device=tensor.device, dtype=torch.float32, generator=generator)
+ noise = (2.0 * noise - 1.0) * math.sqrt(3.0)
+ else:
+ noise = torch.randn(tensor.shape, device=tensor.device, dtype=torch.float32, generator=generator)
+ return noise.to(tensor.dtype)
+
+
+def _expand_first_dim(batch: dict[str, torch.Tensor], n: int) -> dict[str, torch.Tensor]:
+ if n == 1:
+ return batch
+ return {k: v.repeat_interleave(n, dim=0) for k, v in batch.items()}
+
+
+def _noise_target_std(args, train_step: int | None) -> float:
+ if train_step is None or args.sigma_ramp_steps <= 0:
+ return args.noise_std
+ start = args.sigma_start if args.sigma_start is not None else args.noise_std
+ frac = min(max(train_step / args.sigma_ramp_steps, 0.0), 1.0)
+ return float(start + frac * (args.noise_std - start))
+
+
+def _sample_noise_stds(
+ args,
+ batch_size: int,
+ n_trajectories: int,
+ device: str,
+ generator: torch.Generator,
+ train_step: int | None,
+) -> tuple[torch.Tensor, torch.Tensor, float]:
+ """Return per-row perturbation stds and a clean/noisy mask for expanded rows."""
+ total = batch_size * n_trajectories
+ std_target = _noise_target_std(args, train_step)
+ stds = torch.zeros(total, device=device, dtype=torch.float32)
+
+ if args.mode == "baseline_clean":
+ noisy_mask = torch.zeros(total, device=device, dtype=torch.bool)
+ return stds, noisy_mask, std_target
+
+ rollout_id = torch.arange(total, device=device) % n_trajectories
+ noisy_mask = torch.ones(total, device=device, dtype=torch.bool)
+ if args.mode == "multi_perturbed_ce":
+ noisy_mask = rollout_id > 0
+
+ if std_target <= 0:
+ return stds, noisy_mask, std_target
+
+ active_count = int(noisy_mask.sum().item())
+ if active_count == 0:
+ return stds, noisy_mask, std_target
+
+ if args.noise_sampling == "loguniform":
+ ramp_scale = 1.0 if args.noise_std <= 0 else std_target / args.noise_std
+ final_hi = args.noise_max if args.noise_max is not None else args.noise_std
+ final_lo = args.noise_min if args.noise_min is not None else max(final_hi / 10.0, 1e-8)
+ lo = float(final_lo) * ramp_scale
+ hi = float(final_hi) * ramp_scale
+ if hi <= 0:
+ return stds, noisy_mask, std_target
+ lo = max(float(lo), 1e-12)
+ hi = max(float(hi), lo)
+ u = torch.rand(active_count, device=device, dtype=torch.float32, generator=generator)
+ sampled = torch.exp(math.log(lo) + u * (math.log(hi) - math.log(lo)))
+ else:
+ sampled = torch.full((active_count,), std_target, device=device, dtype=torch.float32)
+
+ if args.noise_sampling == "mixture_normal" and args.clean_prob > 0:
+ keep_noisy = torch.rand(active_count, device=device, dtype=torch.float32, generator=generator) >= args.clean_prob
+ sampled = torch.where(keep_noisy, sampled, torch.zeros_like(sampled))
+
+ stds[noisy_mask] = sampled
+ return stds, stds > 0, std_target
+
+
+def _add_state_noise(
+ inner,
+ noise_stds: torch.Tensor,
+ generator: torch.Generator | None,
+ perturb: str,
+ sampling: str,
+):
+ if noise_stds.numel() == 0 or float(noise_stds.max().item()) <= 0:
+ return inner
+ if generator is None:
+ raise ValueError("generator is required when noise is active")
+
+ view_shape = (noise_stds.shape[0],) + (1,) * (inner.z_H.ndim - 1)
+ scaled = noise_stds.view(view_shape)
+ z_h = inner.z_H
+ z_l = inner.z_L
+ if perturb in ("h", "both"):
+ z_h = z_h + scaled.to(z_h.dtype) * _unit_noise_like(z_h, generator, sampling)
+ if perturb in ("l", "both"):
+ z_l = z_l + scaled.to(z_l.dtype) * _unit_noise_like(z_l, generator, sampling)
+ return replace(inner, z_H=z_h, z_L=z_l)
+
+
+def make_loaded_initial_carry(
+ base,
+ batch: dict[str, torch.Tensor],
+ device: str,
+ noise_std: float,
+ generator: torch.Generator | None,
+ perturb: str,
+):
+ """Construct a first-step carry whose current_data is already loaded.
+
+ The ACT wrappers reset any sample with halted=True before the inner forward.
+ To perturb the actual initial recurrent state, we first apply the same reset
+ to H_init/L_init, then mark samples as not halted so the perturbation is not
+ overwritten on the first model call.
+ """
+ with torch.device(device):
+ carry = base.initial_carry(batch)
+
+ reset_flag = torch.ones_like(carry.halted)
+ inner = base.inner.reset_carry(reset_flag, carry.inner_carry)
+
+ if noise_std > 0:
+ if generator is None:
+ raise ValueError("generator is required when noise_std > 0")
+ noise_stds = torch.full((batch["inputs"].shape[0],), noise_std, device=device, dtype=torch.float32)
+ inner = _add_state_noise(inner, noise_stds, generator, perturb, "normal")
+
+ return replace(
+ carry,
+ inner_carry=inner,
+ steps=torch.zeros_like(carry.steps),
+ halted=torch.zeros_like(carry.halted),
+ current_data={k: v for k, v in batch.items()},
+ )
+
+
+def branch_supervised_loss(head, base, batch, carry):
+ loss_sum = 0.0
+ n_loss = 0
+ for _ in range(base.config.halt_max_steps):
+ carry, loss, _metrics, _outputs, all_finish = head(return_keys=[], carry=carry, batch=batch)
+ loss_sum = loss_sum + loss
+ n_loss += 1
+ if all_finish:
+ break
+ return loss_sum / max(n_loss, 1) / batch["inputs"].shape[0]
+
+
+def _token_loss(loss_fn, logits, labels, mask):
+ kwargs = {"ignore_index": IGNORE_LABEL_ID}
+ code = getattr(loss_fn, "__code__", None)
+ arg_names = code.co_varnames[: code.co_argcount + code.co_kwonlyargcount] if code is not None else ()
+ if "valid_mask" in arg_names:
+ kwargs["valid_mask"] = mask
+ return loss_fn(logits, labels, **kwargs)
+
+
+def _step_supervised_loss_vec(head, base, inner_carry, batch, act_step: int):
+ new_inner, logits, (q_halt_logits, q_continue_logits) = base.inner(inner_carry, batch)
+ labels = batch["labels"]
+ with torch.no_grad():
+ mask = labels != IGNORE_LABEL_ID
+ loss_counts = mask.sum(-1)
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)
+ is_correct = mask & (torch.argmax(logits, dim=-1) == labels)
+ seq_is_correct = is_correct.sum(-1) == loss_counts
+
+ lm_loss_vec = (_token_loss(head.loss_fn, logits, labels, mask) / loss_divisor).sum(-1)
+ q_halt_loss_vec = F.binary_cross_entropy_with_logits(
+ q_halt_logits,
+ seq_is_correct.to(q_halt_logits.dtype),
+ reduction="none",
+ )
+
+ q_continue_loss_vec = torch.zeros_like(q_halt_loss_vec)
+ if not getattr(base.config, "no_ACT_continue", False):
+ with torch.no_grad():
+ next_q_halt_logits, next_q_continue_logits = base.inner(new_inner, batch)[-1]
+ is_last_step = act_step + 1 >= base.config.halt_max_steps
+ target_q_continue = torch.sigmoid(
+ torch.where(
+ torch.full_like(next_q_halt_logits, is_last_step, dtype=torch.bool),
+ next_q_halt_logits,
+ torch.maximum(next_q_halt_logits, next_q_continue_logits),
+ )
+ )
+ q_continue_loss_vec = F.binary_cross_entropy_with_logits(
+ q_continue_logits,
+ target_q_continue,
+ reduction="none",
+ )
+
+ return new_inner, lm_loss_vec + 0.5 * (q_halt_loss_vec + q_continue_loss_vec)
+
+
+def fixed_unroll_supervised_loss(args, head, base, batch, device, generator, train_step: int | None):
+ batch_size = batch["inputs"].shape[0]
+ n_trajectories = args.n_trajectories if args.mode == "multi_perturbed_ce" else 1
+ expanded_batch = _expand_first_dim(batch, n_trajectories)
+ total = expanded_batch["inputs"].shape[0]
+
+ noise_stds, actual_noisy_mask, std_target = _sample_noise_stds(
+ args,
+ batch_size=batch_size,
+ n_trajectories=n_trajectories,
+ device=device,
+ generator=generator,
+ train_step=train_step,
+ )
+
+ with torch.device(device):
+ carry = base.initial_carry(expanded_batch)
+ reset_flag = torch.ones_like(carry.halted)
+ inner = base.inner.reset_carry(reset_flag, carry.inner_carry)
+ inner = _add_state_noise(inner, noise_stds, generator, args.perturb, args.noise_sampling)
+
+ loss_vec_sum = torch.zeros(total, device=device, dtype=torch.float32)
+ inner_carry = inner
+ for act_step in range(base.config.halt_max_steps):
+ inner_carry, step_loss_vec = _step_supervised_loss_vec(head, base, inner_carry, expanded_batch, act_step)
+ loss_vec_sum = loss_vec_sum + step_loss_vec.to(torch.float32)
+
+ per_seq = loss_vec_sum / max(base.config.halt_max_steps, 1)
+ loss = per_seq.mean()
+
+ if args.mode == "multi_perturbed_ce":
+ rollout_id = torch.arange(total, device=device) % n_trajectories
+ clean_mask = rollout_id == 0
+ noisy_slot_mask = rollout_id > 0
+ elif args.mode == "baseline_clean":
+ clean_mask = torch.ones(total, device=device, dtype=torch.bool)
+ noisy_slot_mask = torch.zeros(total, device=device, dtype=torch.bool)
+ else:
+ clean_mask = torch.zeros(total, device=device, dtype=torch.bool)
+ noisy_slot_mask = torch.ones(total, device=device, dtype=torch.bool)
+
+ clean_loss = per_seq.detach()[clean_mask].mean().item() if bool(clean_mask.any()) else 0.0
+ noisy_loss = per_seq.detach()[noisy_slot_mask].mean().item() if bool(noisy_slot_mask.any()) else 0.0
+ active_noise = noise_stds[actual_noisy_mask]
+ return loss, {
+ "clean_loss": float(clean_loss),
+ "noisy_loss_mean": float(noisy_loss),
+ "noise_std_target": float(std_target),
+ "noise_std_mean": float(active_noise.mean().item()) if active_noise.numel() else 0.0,
+ "noise_std_max": float(active_noise.max().item()) if active_noise.numel() else 0.0,
+ "effective_batch": int(total),
+ }
+
+
+def train_loss(args, head, base, batch, device, generator, train_step: int | None):
+ if args.rollout_impl == "parallel_fixed":
+ return fixed_unroll_supervised_loss(args, head, base, batch, device, generator, train_step)
+
+ if args.mode == "baseline_clean":
+ carry = make_loaded_initial_carry(
+ base,
+ batch,
+ device,
+ noise_std=0.0,
+ generator=None,
+ perturb=args.perturb,
+ )
+ loss = branch_supervised_loss(head, base, batch, carry)
+ return loss, {
+ "clean_loss": float(loss.detach().item()),
+ "noisy_loss_mean": 0.0,
+ "noise_std_target": 0.0,
+ "noise_std_mean": 0.0,
+ "noise_std_max": 0.0,
+ "effective_batch": int(batch["inputs"].shape[0]),
+ }
+
+ if args.mode == "single_perturbed_ce":
+ std_target = _noise_target_std(args, train_step)
+ carry = make_loaded_initial_carry(
+ base,
+ batch,
+ device,
+ noise_std=std_target,
+ generator=generator,
+ perturb=args.perturb,
+ )
+ loss = branch_supervised_loss(head, base, batch, carry)
+ return loss, {
+ "clean_loss": 0.0,
+ "noisy_loss_mean": float(loss.detach().item()),
+ "noise_std_target": float(std_target),
+ "noise_std_mean": float(std_target),
+ "noise_std_max": float(std_target),
+ "effective_batch": int(batch["inputs"].shape[0]),
+ }
+
+ if args.mode == "multi_perturbed_ce":
+ if args.n_trajectories < 2:
+ raise ValueError("multi_perturbed_ce requires --n-trajectories >= 2")
+ std_target = _noise_target_std(args, train_step)
+
+ clean_carry = make_loaded_initial_carry(
+ base,
+ batch,
+ device,
+ noise_std=0.0,
+ generator=None,
+ perturb=args.perturb,
+ )
+ clean_loss = branch_supervised_loss(head, base, batch, clean_carry)
+ losses = [clean_loss]
+ noisy_vals = []
+ for _ in range(args.n_trajectories - 1):
+ noisy_carry = make_loaded_initial_carry(
+ base,
+ batch,
+ device,
+ noise_std=std_target,
+ generator=generator,
+ perturb=args.perturb,
+ )
+ noisy_loss = branch_supervised_loss(head, base, batch, noisy_carry)
+ losses.append(noisy_loss)
+ noisy_vals.append(float(noisy_loss.detach().item()))
+
+ total = torch.stack(losses).mean()
+ return total, {
+ "clean_loss": float(clean_loss.detach().item()),
+ "noisy_loss_mean": sum(noisy_vals) / max(len(noisy_vals), 1),
+ "noise_std_target": float(std_target),
+ "noise_std_mean": float(std_target),
+ "noise_std_max": float(std_target),
+ "effective_batch": int(batch["inputs"].shape[0] * args.n_trajectories),
+ }
+
+ raise ValueError(f"unknown mode: {args.mode}")
+
+
+def save_checkpoint(head, save_dir: Path, name: str):
+ save_dir.mkdir(parents=True, exist_ok=True)
+ path = save_dir / name
+ torch.save(head.state_dict(), path)
+ return str(path)
+
+
+def save_training_state(
+ head,
+ optim,
+ generator: torch.Generator,
+ args,
+ save_dir: Path,
+ name: str,
+ train_step: int,
+ best_acc: float,
+ best_step: int,
+):
+ save_dir.mkdir(parents=True, exist_ok=True)
+ path = save_dir / name
+ state = {
+ "format": "step9_training_state_v1",
+ "model_state_dict": head.state_dict(),
+ "optimizer_state_dict": optim.state_dict(),
+ "train_step": int(train_step),
+ "best_acc": float(best_acc),
+ "best_step": int(best_step),
+ "args": vars(args),
+ "torch_rng_state": torch.get_rng_state(),
+ "noise_generator_state": generator.get_state(),
+ }
+ if torch.cuda.is_available():
+ state["cuda_rng_state"] = torch.cuda.get_rng_state()
+ torch.save(state, path)
+ return str(path)
+
+
+def load_training_state(path: Path, head, optim, generator: torch.Generator, device: str):
+ state = torch.load(path, map_location=device, weights_only=False)
+ if "model_state_dict" not in state:
+ raise ValueError(f"{path} is not a step9 training-state checkpoint")
+ missing, unexpected = head.load_state_dict(state["model_state_dict"], strict=False)
+ print(f"[resume {path}] missing={len(missing)} unexpected={len(unexpected)}")
+ if "optimizer_state_dict" in state:
+ optim.load_state_dict(state["optimizer_state_dict"])
+ if "torch_rng_state" in state:
+ torch.set_rng_state(state["torch_rng_state"].cpu())
+ if "cuda_rng_state" in state and torch.cuda.is_available():
+ torch.cuda.set_rng_state(state["cuda_rng_state"].cpu())
+ if "noise_generator_state" in state:
+ generator.set_state(state["noise_generator_state"].cpu())
+ return {
+ "train_step": int(state.get("train_step", 0)),
+ "best_acc": state.get("best_acc"),
+ "best_step": int(state.get("best_step", 0)),
+ }
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", choices=["hrm", "trm"], required=True)
+ parser.add_argument("--ckpt-root", required=True)
+ parser.add_argument("--ckpt-name", required=True)
+ parser.add_argument(
+ "--mode",
+ choices=["baseline_clean", "single_perturbed_ce", "multi_perturbed_ce"],
+ required=True,
+ )
+ parser.add_argument("--train-steps", type=int, default=10000)
+ parser.add_argument("--batch-size", type=int, default=8)
+ parser.add_argument("--lr", type=float, default=1e-5)
+ parser.add_argument("--noise-std", type=float, default=1e-3)
+ parser.add_argument("--noise-min", type=float, default=None)
+ parser.add_argument("--noise-max", type=float, default=None)
+ parser.add_argument(
+ "--noise-sampling",
+ choices=["normal", "uniform", "loguniform", "mixture_normal"],
+ default="normal",
+ )
+ parser.add_argument("--clean-prob", type=float, default=0.0)
+ parser.add_argument("--sigma-start", type=float, default=None)
+ parser.add_argument("--sigma-ramp-steps", type=int, default=0)
+ parser.add_argument("--n-trajectories", type=int, default=4)
+ parser.add_argument("--rollout-impl", choices=["serial_act", "parallel_fixed"], default="parallel_fixed")
+ parser.add_argument("--perturb", choices=["h", "l", "both"], default="both")
+ parser.add_argument("--seed", type=int, default=42)
+ parser.add_argument("--eval-every", type=int, default=1000)
+ parser.add_argument("--eval-n", type=int, default=512)
+ parser.add_argument("--eval-batch-size", type=int, default=32)
+ parser.add_argument("--out", default="step9_trajectory_perturb_log.json")
+ parser.add_argument("--save-dir", default=None)
+ parser.add_argument("--save-best", action="store_true")
+ parser.add_argument("--save-final", action="store_true")
+ parser.add_argument("--save-every-eval", action="store_true")
+ parser.add_argument("--save-train-state", action="store_true")
+ parser.add_argument("--resume-state", default=None)
+ args = parser.parse_args()
+
+ torch.manual_seed(args.seed)
+ device = "cuda"
+ head, base, cfg, adam_cls = load_model(args.model, Path(args.ckpt_root), args.ckpt_name, device)
+ data_path = Path(cfg["data_path"])
+ optim = adam_cls(head.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=cfg["weight_decay"])
+ generator = torch.Generator(device=device).manual_seed(args.seed + 900000)
+ resume_step = 0
+ resume_best_acc = None
+ resume_best_step = 0
+ if args.resume_state:
+ resume_info = load_training_state(Path(args.resume_state), head, optim, generator, device)
+ resume_step = resume_info["train_step"]
+ resume_best_acc = resume_info["best_acc"]
+ resume_best_step = resume_info["best_step"]
+ print(f"[resume] train_step={resume_step} best_acc={resume_best_acc} best_step={resume_best_step}", flush=True)
+
+ print(f"\n=== Initial eval (loaded {args.ckpt_name}) ===")
+ acc0, tok0 = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device)
+ print(f" initial: exact_acc={acc0:.4f} token_acc={tok0:.4f}", flush=True)
+
+ log = {
+ "args": vars(args),
+ "initial_acc": acc0,
+ "initial_tok_acc": tok0,
+ "steps": [],
+ "evals": [{"kind": "initial", "train_step": 0, "acc": acc0, "tok_acc": tok0}],
+ "checkpoints": [],
+ "resume_state": args.resume_state,
+ "resume_step": resume_step,
+ }
+ write_log(args.out, log)
+ save_dir = Path(args.save_dir) if args.save_dir else Path(str(args.out)).with_suffix("").with_name(Path(str(args.out)).with_suffix("").name + "_ckpts")
+ best_acc = float(resume_best_acc) if resume_best_acc is not None else acc0
+ best_step = resume_best_step if resume_best_acc is not None else 0
+
+ train_iter = load_train_batches(data_path, args.batch_size, args.train_steps, seed=args.seed)
+ for _ in range(min(resume_step, args.train_steps)):
+ next(train_iter)
+ t0 = time.time()
+ for train_step, batch_cpu in enumerate(train_iter, start=resume_step):
+ batch = move_batch(batch_cpu, device)
+ head.train()
+ freeze_puzzle_embedding(base)
+ optim.zero_grad(set_to_none=True)
+ loss, parts = train_loss(args, head, base, batch, device, generator, train_step + 1)
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_([p for p in head.parameters() if p.requires_grad], 1.0)
+ optim.step()
+
+ rec = {
+ "train_step": train_step + 1,
+ "loss": float(loss.detach().item()),
+ **parts,
+ }
+ log["steps"].append(rec)
+ if train_step % 50 == 0 or train_step == args.train_steps - 1:
+ print(
+ f" T[{train_step + 1:>5}/{args.train_steps}] dt={time.time() - t0:.1f}s "
+ f"loss={rec['loss']:.4f} clean={rec['clean_loss']:.4f} "
+ f"noisy={rec['noisy_loss_mean']:.4f} "
+ f"sigma={rec['noise_std_mean']:.2e}/{rec['noise_std_max']:.2e} "
+ f"effB={rec['effective_batch']}",
+ flush=True,
+ )
+
+ if (train_step + 1) % args.eval_every == 0:
+ acc, tok_acc = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device)
+ print(f" >> EVAL @ step {train_step + 1}: exact_acc={acc:.4f} delta={acc - acc0:+.4f}", flush=True)
+ log["evals"].append({"kind": "task", "train_step": train_step + 1, "acc": acc, "tok_acc": tok_acc})
+ if args.save_every_eval:
+ ckpt_path = save_checkpoint(head, save_dir, f"step_{train_step + 1}.pt")
+ log["checkpoints"].append({"kind": "eval", "train_step": train_step + 1, "acc": acc, "path": ckpt_path})
+ if args.save_best and acc >= best_acc:
+ best_acc = acc
+ best_step = train_step + 1
+ ckpt_path = save_checkpoint(head, save_dir, "best.pt")
+ log["best_acc"] = best_acc
+ log["best_step"] = best_step
+ log["best_checkpoint"] = ckpt_path
+ log["checkpoints"].append({"kind": "best", "train_step": train_step + 1, "acc": acc, "path": ckpt_path})
+ if args.save_train_state:
+ state_path = save_training_state(head, optim, generator, args, save_dir, "best_state.pt", train_step + 1, best_acc, best_step)
+ log["best_state_checkpoint"] = state_path
+ log["checkpoints"].append({"kind": "best_state", "train_step": train_step + 1, "acc": acc, "path": state_path})
+ if args.save_train_state:
+ state_path = save_training_state(head, optim, generator, args, save_dir, "latest_state.pt", train_step + 1, best_acc, best_step)
+ log["latest_state_checkpoint"] = state_path
+ log["checkpoints"].append({"kind": "latest_state", "train_step": train_step + 1, "acc": acc, "path": state_path})
+ write_log(args.out, log)
+
+ acc_f, tok_f = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device)
+ print("\n=== Final eval ===")
+ print(f" initial={acc0:.4f} final={acc_f:.4f} delta={acc_f - acc0:+.4f}", flush=True)
+ log["final_acc"] = acc_f
+ log["final_tok_acc"] = tok_f
+ log["evals"].append({"kind": "final", "train_step": args.train_steps, "acc": acc_f, "tok_acc": tok_f})
+ if args.save_final:
+ ckpt_path = save_checkpoint(head, save_dir, "final.pt")
+ log["final_checkpoint"] = ckpt_path
+ log["checkpoints"].append({"kind": "final", "train_step": args.train_steps, "acc": acc_f, "path": ckpt_path})
+ if args.save_train_state:
+ state_path = save_training_state(head, optim, generator, args, save_dir, "final_state.pt", args.train_steps, best_acc, best_step)
+ log["final_state_checkpoint"] = state_path
+ log["checkpoints"].append({"kind": "final_state", "train_step": args.train_steps, "acc": acc_f, "path": state_path})
+ write_log(args.out, log)
+ print(f"log -> {args.out}")
+
+
+if __name__ == "__main__":
+ main()