"""Step 9: Supervised training with Lyapunov-style initial trajectory perturbations. This is hidden-trajectory augmentation, not a Lyapunov/flossing objective: single_perturbed_ce: sample one tiny perturbation of the initial recursive state and train with the original supervised ACT loss on the same (x, y). multi_perturbed_ce: run one clean trajectory plus K-1 independently perturbed trajectories on the same (x, y), average the original supervised ACT losses, and update. The goal is to enlarge the correct answer basin around the model's nominal initial latent state without directly optimizing Lyapunov exponents or KL consistency to the model's current answer. """ from __future__ import annotations import argparse import json import math import sys import time from dataclasses import replace from pathlib import Path import torch import torch.nn.functional as F FLOSS_DIR = Path(__file__).resolve().parent sys.path.insert(0, str(FLOSS_DIR)) from step7_interfloss import ( # noqa: E402 evaluate, freeze_puzzle_embedding, load_model, load_train_batches, move_batch, write_log, ) IGNORE_LABEL_ID = -100 def _randn_like(tensor: torch.Tensor, generator: torch.Generator) -> torch.Tensor: noise = torch.randn( tensor.shape, device=tensor.device, dtype=torch.float32, generator=generator, ) return noise.to(tensor.dtype) def _unit_noise_like(tensor: torch.Tensor, generator: torch.Generator, sampling: str) -> torch.Tensor: if sampling == "uniform": # Match Normal(0, 1) variance: U[-sqrt(3), sqrt(3)] has std=1. noise = torch.rand(tensor.shape, device=tensor.device, dtype=torch.float32, generator=generator) noise = (2.0 * noise - 1.0) * math.sqrt(3.0) else: noise = torch.randn(tensor.shape, device=tensor.device, dtype=torch.float32, generator=generator) return noise.to(tensor.dtype) def _expand_first_dim(batch: dict[str, torch.Tensor], n: int) -> dict[str, torch.Tensor]: if n == 1: return batch return {k: v.repeat_interleave(n, dim=0) for k, v in batch.items()} def _noise_target_std(args, train_step: int | None) -> float: if train_step is None or args.sigma_ramp_steps <= 0: return args.noise_std start = args.sigma_start if args.sigma_start is not None else args.noise_std frac = min(max(train_step / args.sigma_ramp_steps, 0.0), 1.0) return float(start + frac * (args.noise_std - start)) def _sample_noise_stds( args, batch_size: int, n_trajectories: int, device: str, generator: torch.Generator, train_step: int | None, ) -> tuple[torch.Tensor, torch.Tensor, float]: """Return per-row perturbation stds and a clean/noisy mask for expanded rows.""" total = batch_size * n_trajectories std_target = _noise_target_std(args, train_step) stds = torch.zeros(total, device=device, dtype=torch.float32) if args.mode == "baseline_clean": noisy_mask = torch.zeros(total, device=device, dtype=torch.bool) return stds, noisy_mask, std_target rollout_id = torch.arange(total, device=device) % n_trajectories noisy_mask = torch.ones(total, device=device, dtype=torch.bool) if args.mode == "multi_perturbed_ce": noisy_mask = rollout_id > 0 if std_target <= 0: return stds, noisy_mask, std_target active_count = int(noisy_mask.sum().item()) if active_count == 0: return stds, noisy_mask, std_target if args.noise_sampling == "loguniform": ramp_scale = 1.0 if args.noise_std <= 0 else std_target / args.noise_std final_hi = args.noise_max if args.noise_max is not None else args.noise_std final_lo = args.noise_min if args.noise_min is not None else max(final_hi / 10.0, 1e-8) lo = float(final_lo) * ramp_scale hi = float(final_hi) * ramp_scale if hi <= 0: return stds, noisy_mask, std_target lo = max(float(lo), 1e-12) hi = max(float(hi), lo) u = torch.rand(active_count, device=device, dtype=torch.float32, generator=generator) sampled = torch.exp(math.log(lo) + u * (math.log(hi) - math.log(lo))) else: sampled = torch.full((active_count,), std_target, device=device, dtype=torch.float32) if args.noise_sampling == "mixture_normal" and args.clean_prob > 0: keep_noisy = torch.rand(active_count, device=device, dtype=torch.float32, generator=generator) >= args.clean_prob sampled = torch.where(keep_noisy, sampled, torch.zeros_like(sampled)) stds[noisy_mask] = sampled return stds, stds > 0, std_target def _add_state_noise( inner, noise_stds: torch.Tensor, generator: torch.Generator | None, perturb: str, sampling: str, ): if noise_stds.numel() == 0 or float(noise_stds.max().item()) <= 0: return inner if generator is None: raise ValueError("generator is required when noise is active") view_shape = (noise_stds.shape[0],) + (1,) * (inner.z_H.ndim - 1) scaled = noise_stds.view(view_shape) z_h = inner.z_H z_l = inner.z_L if perturb in ("h", "both"): z_h = z_h + scaled.to(z_h.dtype) * _unit_noise_like(z_h, generator, sampling) if perturb in ("l", "both"): z_l = z_l + scaled.to(z_l.dtype) * _unit_noise_like(z_l, generator, sampling) return replace(inner, z_H=z_h, z_L=z_l) def make_loaded_initial_carry( base, batch: dict[str, torch.Tensor], device: str, noise_std: float, generator: torch.Generator | None, perturb: str, ): """Construct a first-step carry whose current_data is already loaded. The ACT wrappers reset any sample with halted=True before the inner forward. To perturb the actual initial recurrent state, we first apply the same reset to H_init/L_init, then mark samples as not halted so the perturbation is not overwritten on the first model call. """ with torch.device(device): carry = base.initial_carry(batch) reset_flag = torch.ones_like(carry.halted) inner = base.inner.reset_carry(reset_flag, carry.inner_carry) if noise_std > 0: if generator is None: raise ValueError("generator is required when noise_std > 0") noise_stds = torch.full((batch["inputs"].shape[0],), noise_std, device=device, dtype=torch.float32) inner = _add_state_noise(inner, noise_stds, generator, perturb, "normal") return replace( carry, inner_carry=inner, steps=torch.zeros_like(carry.steps), halted=torch.zeros_like(carry.halted), current_data={k: v for k, v in batch.items()}, ) def branch_supervised_loss(head, base, batch, carry): loss_sum = 0.0 n_loss = 0 for _ in range(base.config.halt_max_steps): carry, loss, _metrics, _outputs, all_finish = head(return_keys=[], carry=carry, batch=batch) loss_sum = loss_sum + loss n_loss += 1 if all_finish: break return loss_sum / max(n_loss, 1) / batch["inputs"].shape[0] def _token_loss(loss_fn, logits, labels, mask): kwargs = {"ignore_index": IGNORE_LABEL_ID} code = getattr(loss_fn, "__code__", None) arg_names = code.co_varnames[: code.co_argcount + code.co_kwonlyargcount] if code is not None else () if "valid_mask" in arg_names: kwargs["valid_mask"] = mask return loss_fn(logits, labels, **kwargs) def _step_supervised_loss_vec(head, base, inner_carry, batch, act_step: int): new_inner, logits, (q_halt_logits, q_continue_logits) = base.inner(inner_carry, batch) labels = batch["labels"] with torch.no_grad(): mask = labels != IGNORE_LABEL_ID loss_counts = mask.sum(-1) loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) is_correct = mask & (torch.argmax(logits, dim=-1) == labels) seq_is_correct = is_correct.sum(-1) == loss_counts lm_loss_vec = (_token_loss(head.loss_fn, logits, labels, mask) / loss_divisor).sum(-1) q_halt_loss_vec = F.binary_cross_entropy_with_logits( q_halt_logits, seq_is_correct.to(q_halt_logits.dtype), reduction="none", ) q_continue_loss_vec = torch.zeros_like(q_halt_loss_vec) if not getattr(base.config, "no_ACT_continue", False): with torch.no_grad(): next_q_halt_logits, next_q_continue_logits = base.inner(new_inner, batch)[-1] is_last_step = act_step + 1 >= base.config.halt_max_steps target_q_continue = torch.sigmoid( torch.where( torch.full_like(next_q_halt_logits, is_last_step, dtype=torch.bool), next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits), ) ) q_continue_loss_vec = F.binary_cross_entropy_with_logits( q_continue_logits, target_q_continue, reduction="none", ) return new_inner, lm_loss_vec + 0.5 * (q_halt_loss_vec + q_continue_loss_vec) def fixed_unroll_supervised_loss(args, head, base, batch, device, generator, train_step: int | None): batch_size = batch["inputs"].shape[0] n_trajectories = args.n_trajectories if args.mode == "multi_perturbed_ce" else 1 expanded_batch = _expand_first_dim(batch, n_trajectories) total = expanded_batch["inputs"].shape[0] noise_stds, actual_noisy_mask, std_target = _sample_noise_stds( args, batch_size=batch_size, n_trajectories=n_trajectories, device=device, generator=generator, train_step=train_step, ) with torch.device(device): carry = base.initial_carry(expanded_batch) reset_flag = torch.ones_like(carry.halted) inner = base.inner.reset_carry(reset_flag, carry.inner_carry) inner = _add_state_noise(inner, noise_stds, generator, args.perturb, args.noise_sampling) loss_vec_sum = torch.zeros(total, device=device, dtype=torch.float32) inner_carry = inner for act_step in range(base.config.halt_max_steps): inner_carry, step_loss_vec = _step_supervised_loss_vec(head, base, inner_carry, expanded_batch, act_step) loss_vec_sum = loss_vec_sum + step_loss_vec.to(torch.float32) per_seq = loss_vec_sum / max(base.config.halt_max_steps, 1) loss = per_seq.mean() if args.mode == "multi_perturbed_ce": rollout_id = torch.arange(total, device=device) % n_trajectories clean_mask = rollout_id == 0 noisy_slot_mask = rollout_id > 0 elif args.mode == "baseline_clean": clean_mask = torch.ones(total, device=device, dtype=torch.bool) noisy_slot_mask = torch.zeros(total, device=device, dtype=torch.bool) else: clean_mask = torch.zeros(total, device=device, dtype=torch.bool) noisy_slot_mask = torch.ones(total, device=device, dtype=torch.bool) clean_loss = per_seq.detach()[clean_mask].mean().item() if bool(clean_mask.any()) else 0.0 noisy_loss = per_seq.detach()[noisy_slot_mask].mean().item() if bool(noisy_slot_mask.any()) else 0.0 active_noise = noise_stds[actual_noisy_mask] return loss, { "clean_loss": float(clean_loss), "noisy_loss_mean": float(noisy_loss), "noise_std_target": float(std_target), "noise_std_mean": float(active_noise.mean().item()) if active_noise.numel() else 0.0, "noise_std_max": float(active_noise.max().item()) if active_noise.numel() else 0.0, "effective_batch": int(total), } def train_loss(args, head, base, batch, device, generator, train_step: int | None): if args.rollout_impl == "parallel_fixed": return fixed_unroll_supervised_loss(args, head, base, batch, device, generator, train_step) if args.mode == "baseline_clean": carry = make_loaded_initial_carry( base, batch, device, noise_std=0.0, generator=None, perturb=args.perturb, ) loss = branch_supervised_loss(head, base, batch, carry) return loss, { "clean_loss": float(loss.detach().item()), "noisy_loss_mean": 0.0, "noise_std_target": 0.0, "noise_std_mean": 0.0, "noise_std_max": 0.0, "effective_batch": int(batch["inputs"].shape[0]), } if args.mode == "single_perturbed_ce": std_target = _noise_target_std(args, train_step) carry = make_loaded_initial_carry( base, batch, device, noise_std=std_target, generator=generator, perturb=args.perturb, ) loss = branch_supervised_loss(head, base, batch, carry) return loss, { "clean_loss": 0.0, "noisy_loss_mean": float(loss.detach().item()), "noise_std_target": float(std_target), "noise_std_mean": float(std_target), "noise_std_max": float(std_target), "effective_batch": int(batch["inputs"].shape[0]), } if args.mode == "multi_perturbed_ce": if args.n_trajectories < 2: raise ValueError("multi_perturbed_ce requires --n-trajectories >= 2") std_target = _noise_target_std(args, train_step) clean_carry = make_loaded_initial_carry( base, batch, device, noise_std=0.0, generator=None, perturb=args.perturb, ) clean_loss = branch_supervised_loss(head, base, batch, clean_carry) losses = [clean_loss] noisy_vals = [] for _ in range(args.n_trajectories - 1): noisy_carry = make_loaded_initial_carry( base, batch, device, noise_std=std_target, generator=generator, perturb=args.perturb, ) noisy_loss = branch_supervised_loss(head, base, batch, noisy_carry) losses.append(noisy_loss) noisy_vals.append(float(noisy_loss.detach().item())) total = torch.stack(losses).mean() return total, { "clean_loss": float(clean_loss.detach().item()), "noisy_loss_mean": sum(noisy_vals) / max(len(noisy_vals), 1), "noise_std_target": float(std_target), "noise_std_mean": float(std_target), "noise_std_max": float(std_target), "effective_batch": int(batch["inputs"].shape[0] * args.n_trajectories), } raise ValueError(f"unknown mode: {args.mode}") def save_checkpoint(head, save_dir: Path, name: str): save_dir.mkdir(parents=True, exist_ok=True) path = save_dir / name torch.save(head.state_dict(), path) return str(path) def save_training_state( head, optim, generator: torch.Generator, args, save_dir: Path, name: str, train_step: int, best_acc: float, best_step: int, ): save_dir.mkdir(parents=True, exist_ok=True) path = save_dir / name state = { "format": "step9_training_state_v1", "model_state_dict": head.state_dict(), "optimizer_state_dict": optim.state_dict(), "train_step": int(train_step), "best_acc": float(best_acc), "best_step": int(best_step), "args": vars(args), "torch_rng_state": torch.get_rng_state(), "noise_generator_state": generator.get_state(), } if torch.cuda.is_available(): state["cuda_rng_state"] = torch.cuda.get_rng_state() torch.save(state, path) return str(path) def load_training_state(path: Path, head, optim, generator: torch.Generator, device: str): state = torch.load(path, map_location=device, weights_only=False) if "model_state_dict" not in state: raise ValueError(f"{path} is not a step9 training-state checkpoint") missing, unexpected = head.load_state_dict(state["model_state_dict"], strict=False) print(f"[resume {path}] missing={len(missing)} unexpected={len(unexpected)}") if "optimizer_state_dict" in state: optim.load_state_dict(state["optimizer_state_dict"]) if "torch_rng_state" in state: torch.set_rng_state(state["torch_rng_state"].cpu()) if "cuda_rng_state" in state and torch.cuda.is_available(): torch.cuda.set_rng_state(state["cuda_rng_state"].cpu()) if "noise_generator_state" in state: generator.set_state(state["noise_generator_state"].cpu()) return { "train_step": int(state.get("train_step", 0)), "best_acc": state.get("best_acc"), "best_step": int(state.get("best_step", 0)), } def main(): parser = argparse.ArgumentParser() parser.add_argument("--model", choices=["hrm", "trm"], required=True) parser.add_argument("--ckpt-root", required=True) parser.add_argument("--ckpt-name", required=True) parser.add_argument( "--mode", choices=["baseline_clean", "single_perturbed_ce", "multi_perturbed_ce"], required=True, ) parser.add_argument("--train-steps", type=int, default=10000) parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--lr", type=float, default=1e-5) parser.add_argument("--noise-std", type=float, default=1e-3) parser.add_argument("--noise-min", type=float, default=None) parser.add_argument("--noise-max", type=float, default=None) parser.add_argument( "--noise-sampling", choices=["normal", "uniform", "loguniform", "mixture_normal"], default="normal", ) parser.add_argument("--clean-prob", type=float, default=0.0) parser.add_argument("--sigma-start", type=float, default=None) parser.add_argument("--sigma-ramp-steps", type=int, default=0) parser.add_argument("--n-trajectories", type=int, default=4) parser.add_argument("--rollout-impl", choices=["serial_act", "parallel_fixed"], default="parallel_fixed") parser.add_argument("--perturb", choices=["h", "l", "both"], default="both") parser.add_argument("--seed", type=int, default=42) parser.add_argument("--eval-every", type=int, default=1000) parser.add_argument("--eval-n", type=int, default=512) parser.add_argument("--eval-batch-size", type=int, default=32) parser.add_argument("--out", default="step9_trajectory_perturb_log.json") parser.add_argument("--save-dir", default=None) parser.add_argument("--save-best", action="store_true") parser.add_argument("--save-final", action="store_true") parser.add_argument("--save-every-eval", action="store_true") parser.add_argument("--save-train-state", action="store_true") parser.add_argument("--resume-state", default=None) args = parser.parse_args() torch.manual_seed(args.seed) device = "cuda" head, base, cfg, adam_cls = load_model(args.model, Path(args.ckpt_root), args.ckpt_name, device) data_path = Path(cfg["data_path"]) optim = adam_cls(head.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=cfg["weight_decay"]) generator = torch.Generator(device=device).manual_seed(args.seed + 900000) resume_step = 0 resume_best_acc = None resume_best_step = 0 if args.resume_state: resume_info = load_training_state(Path(args.resume_state), head, optim, generator, device) resume_step = resume_info["train_step"] resume_best_acc = resume_info["best_acc"] resume_best_step = resume_info["best_step"] print(f"[resume] train_step={resume_step} best_acc={resume_best_acc} best_step={resume_best_step}", flush=True) print(f"\n=== Initial eval (loaded {args.ckpt_name}) ===") acc0, tok0 = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) print(f" initial: exact_acc={acc0:.4f} token_acc={tok0:.4f}", flush=True) log = { "args": vars(args), "initial_acc": acc0, "initial_tok_acc": tok0, "steps": [], "evals": [{"kind": "initial", "train_step": 0, "acc": acc0, "tok_acc": tok0}], "checkpoints": [], "resume_state": args.resume_state, "resume_step": resume_step, } write_log(args.out, log) save_dir = Path(args.save_dir) if args.save_dir else Path(str(args.out)).with_suffix("").with_name(Path(str(args.out)).with_suffix("").name + "_ckpts") best_acc = float(resume_best_acc) if resume_best_acc is not None else acc0 best_step = resume_best_step if resume_best_acc is not None else 0 train_iter = load_train_batches(data_path, args.batch_size, args.train_steps, seed=args.seed) for _ in range(min(resume_step, args.train_steps)): next(train_iter) t0 = time.time() for train_step, batch_cpu in enumerate(train_iter, start=resume_step): batch = move_batch(batch_cpu, device) head.train() freeze_puzzle_embedding(base) optim.zero_grad(set_to_none=True) loss, parts = train_loss(args, head, base, batch, device, generator, train_step + 1) loss.backward() torch.nn.utils.clip_grad_norm_([p for p in head.parameters() if p.requires_grad], 1.0) optim.step() rec = { "train_step": train_step + 1, "loss": float(loss.detach().item()), **parts, } log["steps"].append(rec) if train_step % 50 == 0 or train_step == args.train_steps - 1: print( f" T[{train_step + 1:>5}/{args.train_steps}] dt={time.time() - t0:.1f}s " f"loss={rec['loss']:.4f} clean={rec['clean_loss']:.4f} " f"noisy={rec['noisy_loss_mean']:.4f} " f"sigma={rec['noise_std_mean']:.2e}/{rec['noise_std_max']:.2e} " f"effB={rec['effective_batch']}", flush=True, ) if (train_step + 1) % args.eval_every == 0: acc, tok_acc = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) print(f" >> EVAL @ step {train_step + 1}: exact_acc={acc:.4f} delta={acc - acc0:+.4f}", flush=True) log["evals"].append({"kind": "task", "train_step": train_step + 1, "acc": acc, "tok_acc": tok_acc}) if args.save_every_eval: ckpt_path = save_checkpoint(head, save_dir, f"step_{train_step + 1}.pt") log["checkpoints"].append({"kind": "eval", "train_step": train_step + 1, "acc": acc, "path": ckpt_path}) if args.save_best and acc >= best_acc: best_acc = acc best_step = train_step + 1 ckpt_path = save_checkpoint(head, save_dir, "best.pt") log["best_acc"] = best_acc log["best_step"] = best_step log["best_checkpoint"] = ckpt_path log["checkpoints"].append({"kind": "best", "train_step": train_step + 1, "acc": acc, "path": ckpt_path}) if args.save_train_state: state_path = save_training_state(head, optim, generator, args, save_dir, "best_state.pt", train_step + 1, best_acc, best_step) log["best_state_checkpoint"] = state_path log["checkpoints"].append({"kind": "best_state", "train_step": train_step + 1, "acc": acc, "path": state_path}) if args.save_train_state: state_path = save_training_state(head, optim, generator, args, save_dir, "latest_state.pt", train_step + 1, best_acc, best_step) log["latest_state_checkpoint"] = state_path log["checkpoints"].append({"kind": "latest_state", "train_step": train_step + 1, "acc": acc, "path": state_path}) write_log(args.out, log) acc_f, tok_f = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) print("\n=== Final eval ===") print(f" initial={acc0:.4f} final={acc_f:.4f} delta={acc_f - acc0:+.4f}", flush=True) log["final_acc"] = acc_f log["final_tok_acc"] = tok_f log["evals"].append({"kind": "final", "train_step": args.train_steps, "acc": acc_f, "tok_acc": tok_f}) if args.save_final: ckpt_path = save_checkpoint(head, save_dir, "final.pt") log["final_checkpoint"] = ckpt_path log["checkpoints"].append({"kind": "final", "train_step": args.train_steps, "acc": acc_f, "path": ckpt_path}) if args.save_train_state: state_path = save_training_state(head, optim, generator, args, save_dir, "final_state.pt", args.train_steps, best_acc, best_step) log["final_state_checkpoint"] = state_path log["checkpoints"].append({"kind": "final_state", "train_step": args.train_steps, "acc": acc_f, "path": state_path}) write_log(args.out, log) print(f"log -> {args.out}") if __name__ == "__main__": main()