diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
| commit | 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch) | |
| tree | c29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/step9_trajectory_perturb_train.py | |
Curated export for clone-and-run Maze training (2x A6000) + diagnostics.
trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible).
Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/step9_trajectory_perturb_train.py')
| -rw-r--r-- | research/flossing/step9_trajectory_perturb_train.py | 595 |
1 files changed, 595 insertions, 0 deletions
diff --git a/research/flossing/step9_trajectory_perturb_train.py b/research/flossing/step9_trajectory_perturb_train.py new file mode 100644 index 0000000..61cbf04 --- /dev/null +++ b/research/flossing/step9_trajectory_perturb_train.py @@ -0,0 +1,595 @@ +"""Step 9: Supervised training with Lyapunov-style initial trajectory perturbations. + +This is hidden-trajectory augmentation, not a Lyapunov/flossing objective: + + single_perturbed_ce: + sample one tiny perturbation of the initial recursive state and train with + the original supervised ACT loss on the same (x, y). + + multi_perturbed_ce: + run one clean trajectory plus K-1 independently perturbed trajectories on + the same (x, y), average the original supervised ACT losses, and update. + +The goal is to enlarge the correct answer basin around the model's nominal +initial latent state without directly optimizing Lyapunov exponents or KL +consistency to the model's current answer. +""" +from __future__ import annotations + +import argparse +import json +import math +import sys +import time +from dataclasses import replace +from pathlib import Path + +import torch +import torch.nn.functional as F + +FLOSS_DIR = Path(__file__).resolve().parent +sys.path.insert(0, str(FLOSS_DIR)) + +from step7_interfloss import ( # noqa: E402 + evaluate, + freeze_puzzle_embedding, + load_model, + load_train_batches, + move_batch, + write_log, +) + +IGNORE_LABEL_ID = -100 + + +def _randn_like(tensor: torch.Tensor, generator: torch.Generator) -> torch.Tensor: + noise = torch.randn( + tensor.shape, + device=tensor.device, + dtype=torch.float32, + generator=generator, + ) + return noise.to(tensor.dtype) + + +def _unit_noise_like(tensor: torch.Tensor, generator: torch.Generator, sampling: str) -> torch.Tensor: + if sampling == "uniform": + # Match Normal(0, 1) variance: U[-sqrt(3), sqrt(3)] has std=1. + noise = torch.rand(tensor.shape, device=tensor.device, dtype=torch.float32, generator=generator) + noise = (2.0 * noise - 1.0) * math.sqrt(3.0) + else: + noise = torch.randn(tensor.shape, device=tensor.device, dtype=torch.float32, generator=generator) + return noise.to(tensor.dtype) + + +def _expand_first_dim(batch: dict[str, torch.Tensor], n: int) -> dict[str, torch.Tensor]: + if n == 1: + return batch + return {k: v.repeat_interleave(n, dim=0) for k, v in batch.items()} + + +def _noise_target_std(args, train_step: int | None) -> float: + if train_step is None or args.sigma_ramp_steps <= 0: + return args.noise_std + start = args.sigma_start if args.sigma_start is not None else args.noise_std + frac = min(max(train_step / args.sigma_ramp_steps, 0.0), 1.0) + return float(start + frac * (args.noise_std - start)) + + +def _sample_noise_stds( + args, + batch_size: int, + n_trajectories: int, + device: str, + generator: torch.Generator, + train_step: int | None, +) -> tuple[torch.Tensor, torch.Tensor, float]: + """Return per-row perturbation stds and a clean/noisy mask for expanded rows.""" + total = batch_size * n_trajectories + std_target = _noise_target_std(args, train_step) + stds = torch.zeros(total, device=device, dtype=torch.float32) + + if args.mode == "baseline_clean": + noisy_mask = torch.zeros(total, device=device, dtype=torch.bool) + return stds, noisy_mask, std_target + + rollout_id = torch.arange(total, device=device) % n_trajectories + noisy_mask = torch.ones(total, device=device, dtype=torch.bool) + if args.mode == "multi_perturbed_ce": + noisy_mask = rollout_id > 0 + + if std_target <= 0: + return stds, noisy_mask, std_target + + active_count = int(noisy_mask.sum().item()) + if active_count == 0: + return stds, noisy_mask, std_target + + if args.noise_sampling == "loguniform": + ramp_scale = 1.0 if args.noise_std <= 0 else std_target / args.noise_std + final_hi = args.noise_max if args.noise_max is not None else args.noise_std + final_lo = args.noise_min if args.noise_min is not None else max(final_hi / 10.0, 1e-8) + lo = float(final_lo) * ramp_scale + hi = float(final_hi) * ramp_scale + if hi <= 0: + return stds, noisy_mask, std_target + lo = max(float(lo), 1e-12) + hi = max(float(hi), lo) + u = torch.rand(active_count, device=device, dtype=torch.float32, generator=generator) + sampled = torch.exp(math.log(lo) + u * (math.log(hi) - math.log(lo))) + else: + sampled = torch.full((active_count,), std_target, device=device, dtype=torch.float32) + + if args.noise_sampling == "mixture_normal" and args.clean_prob > 0: + keep_noisy = torch.rand(active_count, device=device, dtype=torch.float32, generator=generator) >= args.clean_prob + sampled = torch.where(keep_noisy, sampled, torch.zeros_like(sampled)) + + stds[noisy_mask] = sampled + return stds, stds > 0, std_target + + +def _add_state_noise( + inner, + noise_stds: torch.Tensor, + generator: torch.Generator | None, + perturb: str, + sampling: str, +): + if noise_stds.numel() == 0 or float(noise_stds.max().item()) <= 0: + return inner + if generator is None: + raise ValueError("generator is required when noise is active") + + view_shape = (noise_stds.shape[0],) + (1,) * (inner.z_H.ndim - 1) + scaled = noise_stds.view(view_shape) + z_h = inner.z_H + z_l = inner.z_L + if perturb in ("h", "both"): + z_h = z_h + scaled.to(z_h.dtype) * _unit_noise_like(z_h, generator, sampling) + if perturb in ("l", "both"): + z_l = z_l + scaled.to(z_l.dtype) * _unit_noise_like(z_l, generator, sampling) + return replace(inner, z_H=z_h, z_L=z_l) + + +def make_loaded_initial_carry( + base, + batch: dict[str, torch.Tensor], + device: str, + noise_std: float, + generator: torch.Generator | None, + perturb: str, +): + """Construct a first-step carry whose current_data is already loaded. + + The ACT wrappers reset any sample with halted=True before the inner forward. + To perturb the actual initial recurrent state, we first apply the same reset + to H_init/L_init, then mark samples as not halted so the perturbation is not + overwritten on the first model call. + """ + with torch.device(device): + carry = base.initial_carry(batch) + + reset_flag = torch.ones_like(carry.halted) + inner = base.inner.reset_carry(reset_flag, carry.inner_carry) + + if noise_std > 0: + if generator is None: + raise ValueError("generator is required when noise_std > 0") + noise_stds = torch.full((batch["inputs"].shape[0],), noise_std, device=device, dtype=torch.float32) + inner = _add_state_noise(inner, noise_stds, generator, perturb, "normal") + + return replace( + carry, + inner_carry=inner, + steps=torch.zeros_like(carry.steps), + halted=torch.zeros_like(carry.halted), + current_data={k: v for k, v in batch.items()}, + ) + + +def branch_supervised_loss(head, base, batch, carry): + loss_sum = 0.0 + n_loss = 0 + for _ in range(base.config.halt_max_steps): + carry, loss, _metrics, _outputs, all_finish = head(return_keys=[], carry=carry, batch=batch) + loss_sum = loss_sum + loss + n_loss += 1 + if all_finish: + break + return loss_sum / max(n_loss, 1) / batch["inputs"].shape[0] + + +def _token_loss(loss_fn, logits, labels, mask): + kwargs = {"ignore_index": IGNORE_LABEL_ID} + code = getattr(loss_fn, "__code__", None) + arg_names = code.co_varnames[: code.co_argcount + code.co_kwonlyargcount] if code is not None else () + if "valid_mask" in arg_names: + kwargs["valid_mask"] = mask + return loss_fn(logits, labels, **kwargs) + + +def _step_supervised_loss_vec(head, base, inner_carry, batch, act_step: int): + new_inner, logits, (q_halt_logits, q_continue_logits) = base.inner(inner_carry, batch) + labels = batch["labels"] + with torch.no_grad(): + mask = labels != IGNORE_LABEL_ID + loss_counts = mask.sum(-1) + loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) + is_correct = mask & (torch.argmax(logits, dim=-1) == labels) + seq_is_correct = is_correct.sum(-1) == loss_counts + + lm_loss_vec = (_token_loss(head.loss_fn, logits, labels, mask) / loss_divisor).sum(-1) + q_halt_loss_vec = F.binary_cross_entropy_with_logits( + q_halt_logits, + seq_is_correct.to(q_halt_logits.dtype), + reduction="none", + ) + + q_continue_loss_vec = torch.zeros_like(q_halt_loss_vec) + if not getattr(base.config, "no_ACT_continue", False): + with torch.no_grad(): + next_q_halt_logits, next_q_continue_logits = base.inner(new_inner, batch)[-1] + is_last_step = act_step + 1 >= base.config.halt_max_steps + target_q_continue = torch.sigmoid( + torch.where( + torch.full_like(next_q_halt_logits, is_last_step, dtype=torch.bool), + next_q_halt_logits, + torch.maximum(next_q_halt_logits, next_q_continue_logits), + ) + ) + q_continue_loss_vec = F.binary_cross_entropy_with_logits( + q_continue_logits, + target_q_continue, + reduction="none", + ) + + return new_inner, lm_loss_vec + 0.5 * (q_halt_loss_vec + q_continue_loss_vec) + + +def fixed_unroll_supervised_loss(args, head, base, batch, device, generator, train_step: int | None): + batch_size = batch["inputs"].shape[0] + n_trajectories = args.n_trajectories if args.mode == "multi_perturbed_ce" else 1 + expanded_batch = _expand_first_dim(batch, n_trajectories) + total = expanded_batch["inputs"].shape[0] + + noise_stds, actual_noisy_mask, std_target = _sample_noise_stds( + args, + batch_size=batch_size, + n_trajectories=n_trajectories, + device=device, + generator=generator, + train_step=train_step, + ) + + with torch.device(device): + carry = base.initial_carry(expanded_batch) + reset_flag = torch.ones_like(carry.halted) + inner = base.inner.reset_carry(reset_flag, carry.inner_carry) + inner = _add_state_noise(inner, noise_stds, generator, args.perturb, args.noise_sampling) + + loss_vec_sum = torch.zeros(total, device=device, dtype=torch.float32) + inner_carry = inner + for act_step in range(base.config.halt_max_steps): + inner_carry, step_loss_vec = _step_supervised_loss_vec(head, base, inner_carry, expanded_batch, act_step) + loss_vec_sum = loss_vec_sum + step_loss_vec.to(torch.float32) + + per_seq = loss_vec_sum / max(base.config.halt_max_steps, 1) + loss = per_seq.mean() + + if args.mode == "multi_perturbed_ce": + rollout_id = torch.arange(total, device=device) % n_trajectories + clean_mask = rollout_id == 0 + noisy_slot_mask = rollout_id > 0 + elif args.mode == "baseline_clean": + clean_mask = torch.ones(total, device=device, dtype=torch.bool) + noisy_slot_mask = torch.zeros(total, device=device, dtype=torch.bool) + else: + clean_mask = torch.zeros(total, device=device, dtype=torch.bool) + noisy_slot_mask = torch.ones(total, device=device, dtype=torch.bool) + + clean_loss = per_seq.detach()[clean_mask].mean().item() if bool(clean_mask.any()) else 0.0 + noisy_loss = per_seq.detach()[noisy_slot_mask].mean().item() if bool(noisy_slot_mask.any()) else 0.0 + active_noise = noise_stds[actual_noisy_mask] + return loss, { + "clean_loss": float(clean_loss), + "noisy_loss_mean": float(noisy_loss), + "noise_std_target": float(std_target), + "noise_std_mean": float(active_noise.mean().item()) if active_noise.numel() else 0.0, + "noise_std_max": float(active_noise.max().item()) if active_noise.numel() else 0.0, + "effective_batch": int(total), + } + + +def train_loss(args, head, base, batch, device, generator, train_step: int | None): + if args.rollout_impl == "parallel_fixed": + return fixed_unroll_supervised_loss(args, head, base, batch, device, generator, train_step) + + if args.mode == "baseline_clean": + carry = make_loaded_initial_carry( + base, + batch, + device, + noise_std=0.0, + generator=None, + perturb=args.perturb, + ) + loss = branch_supervised_loss(head, base, batch, carry) + return loss, { + "clean_loss": float(loss.detach().item()), + "noisy_loss_mean": 0.0, + "noise_std_target": 0.0, + "noise_std_mean": 0.0, + "noise_std_max": 0.0, + "effective_batch": int(batch["inputs"].shape[0]), + } + + if args.mode == "single_perturbed_ce": + std_target = _noise_target_std(args, train_step) + carry = make_loaded_initial_carry( + base, + batch, + device, + noise_std=std_target, + generator=generator, + perturb=args.perturb, + ) + loss = branch_supervised_loss(head, base, batch, carry) + return loss, { + "clean_loss": 0.0, + "noisy_loss_mean": float(loss.detach().item()), + "noise_std_target": float(std_target), + "noise_std_mean": float(std_target), + "noise_std_max": float(std_target), + "effective_batch": int(batch["inputs"].shape[0]), + } + + if args.mode == "multi_perturbed_ce": + if args.n_trajectories < 2: + raise ValueError("multi_perturbed_ce requires --n-trajectories >= 2") + std_target = _noise_target_std(args, train_step) + + clean_carry = make_loaded_initial_carry( + base, + batch, + device, + noise_std=0.0, + generator=None, + perturb=args.perturb, + ) + clean_loss = branch_supervised_loss(head, base, batch, clean_carry) + losses = [clean_loss] + noisy_vals = [] + for _ in range(args.n_trajectories - 1): + noisy_carry = make_loaded_initial_carry( + base, + batch, + device, + noise_std=std_target, + generator=generator, + perturb=args.perturb, + ) + noisy_loss = branch_supervised_loss(head, base, batch, noisy_carry) + losses.append(noisy_loss) + noisy_vals.append(float(noisy_loss.detach().item())) + + total = torch.stack(losses).mean() + return total, { + "clean_loss": float(clean_loss.detach().item()), + "noisy_loss_mean": sum(noisy_vals) / max(len(noisy_vals), 1), + "noise_std_target": float(std_target), + "noise_std_mean": float(std_target), + "noise_std_max": float(std_target), + "effective_batch": int(batch["inputs"].shape[0] * args.n_trajectories), + } + + raise ValueError(f"unknown mode: {args.mode}") + + +def save_checkpoint(head, save_dir: Path, name: str): + save_dir.mkdir(parents=True, exist_ok=True) + path = save_dir / name + torch.save(head.state_dict(), path) + return str(path) + + +def save_training_state( + head, + optim, + generator: torch.Generator, + args, + save_dir: Path, + name: str, + train_step: int, + best_acc: float, + best_step: int, +): + save_dir.mkdir(parents=True, exist_ok=True) + path = save_dir / name + state = { + "format": "step9_training_state_v1", + "model_state_dict": head.state_dict(), + "optimizer_state_dict": optim.state_dict(), + "train_step": int(train_step), + "best_acc": float(best_acc), + "best_step": int(best_step), + "args": vars(args), + "torch_rng_state": torch.get_rng_state(), + "noise_generator_state": generator.get_state(), + } + if torch.cuda.is_available(): + state["cuda_rng_state"] = torch.cuda.get_rng_state() + torch.save(state, path) + return str(path) + + +def load_training_state(path: Path, head, optim, generator: torch.Generator, device: str): + state = torch.load(path, map_location=device, weights_only=False) + if "model_state_dict" not in state: + raise ValueError(f"{path} is not a step9 training-state checkpoint") + missing, unexpected = head.load_state_dict(state["model_state_dict"], strict=False) + print(f"[resume {path}] missing={len(missing)} unexpected={len(unexpected)}") + if "optimizer_state_dict" in state: + optim.load_state_dict(state["optimizer_state_dict"]) + if "torch_rng_state" in state: + torch.set_rng_state(state["torch_rng_state"].cpu()) + if "cuda_rng_state" in state and torch.cuda.is_available(): + torch.cuda.set_rng_state(state["cuda_rng_state"].cpu()) + if "noise_generator_state" in state: + generator.set_state(state["noise_generator_state"].cpu()) + return { + "train_step": int(state.get("train_step", 0)), + "best_acc": state.get("best_acc"), + "best_step": int(state.get("best_step", 0)), + } + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", choices=["hrm", "trm"], required=True) + parser.add_argument("--ckpt-root", required=True) + parser.add_argument("--ckpt-name", required=True) + parser.add_argument( + "--mode", + choices=["baseline_clean", "single_perturbed_ce", "multi_perturbed_ce"], + required=True, + ) + parser.add_argument("--train-steps", type=int, default=10000) + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument("--lr", type=float, default=1e-5) + parser.add_argument("--noise-std", type=float, default=1e-3) + parser.add_argument("--noise-min", type=float, default=None) + parser.add_argument("--noise-max", type=float, default=None) + parser.add_argument( + "--noise-sampling", + choices=["normal", "uniform", "loguniform", "mixture_normal"], + default="normal", + ) + parser.add_argument("--clean-prob", type=float, default=0.0) + parser.add_argument("--sigma-start", type=float, default=None) + parser.add_argument("--sigma-ramp-steps", type=int, default=0) + parser.add_argument("--n-trajectories", type=int, default=4) + parser.add_argument("--rollout-impl", choices=["serial_act", "parallel_fixed"], default="parallel_fixed") + parser.add_argument("--perturb", choices=["h", "l", "both"], default="both") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--eval-every", type=int, default=1000) + parser.add_argument("--eval-n", type=int, default=512) + parser.add_argument("--eval-batch-size", type=int, default=32) + parser.add_argument("--out", default="step9_trajectory_perturb_log.json") + parser.add_argument("--save-dir", default=None) + parser.add_argument("--save-best", action="store_true") + parser.add_argument("--save-final", action="store_true") + parser.add_argument("--save-every-eval", action="store_true") + parser.add_argument("--save-train-state", action="store_true") + parser.add_argument("--resume-state", default=None) + args = parser.parse_args() + + torch.manual_seed(args.seed) + device = "cuda" + head, base, cfg, adam_cls = load_model(args.model, Path(args.ckpt_root), args.ckpt_name, device) + data_path = Path(cfg["data_path"]) + optim = adam_cls(head.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=cfg["weight_decay"]) + generator = torch.Generator(device=device).manual_seed(args.seed + 900000) + resume_step = 0 + resume_best_acc = None + resume_best_step = 0 + if args.resume_state: + resume_info = load_training_state(Path(args.resume_state), head, optim, generator, device) + resume_step = resume_info["train_step"] + resume_best_acc = resume_info["best_acc"] + resume_best_step = resume_info["best_step"] + print(f"[resume] train_step={resume_step} best_acc={resume_best_acc} best_step={resume_best_step}", flush=True) + + print(f"\n=== Initial eval (loaded {args.ckpt_name}) ===") + acc0, tok0 = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) + print(f" initial: exact_acc={acc0:.4f} token_acc={tok0:.4f}", flush=True) + + log = { + "args": vars(args), + "initial_acc": acc0, + "initial_tok_acc": tok0, + "steps": [], + "evals": [{"kind": "initial", "train_step": 0, "acc": acc0, "tok_acc": tok0}], + "checkpoints": [], + "resume_state": args.resume_state, + "resume_step": resume_step, + } + write_log(args.out, log) + save_dir = Path(args.save_dir) if args.save_dir else Path(str(args.out)).with_suffix("").with_name(Path(str(args.out)).with_suffix("").name + "_ckpts") + best_acc = float(resume_best_acc) if resume_best_acc is not None else acc0 + best_step = resume_best_step if resume_best_acc is not None else 0 + + train_iter = load_train_batches(data_path, args.batch_size, args.train_steps, seed=args.seed) + for _ in range(min(resume_step, args.train_steps)): + next(train_iter) + t0 = time.time() + for train_step, batch_cpu in enumerate(train_iter, start=resume_step): + batch = move_batch(batch_cpu, device) + head.train() + freeze_puzzle_embedding(base) + optim.zero_grad(set_to_none=True) + loss, parts = train_loss(args, head, base, batch, device, generator, train_step + 1) + loss.backward() + torch.nn.utils.clip_grad_norm_([p for p in head.parameters() if p.requires_grad], 1.0) + optim.step() + + rec = { + "train_step": train_step + 1, + "loss": float(loss.detach().item()), + **parts, + } + log["steps"].append(rec) + if train_step % 50 == 0 or train_step == args.train_steps - 1: + print( + f" T[{train_step + 1:>5}/{args.train_steps}] dt={time.time() - t0:.1f}s " + f"loss={rec['loss']:.4f} clean={rec['clean_loss']:.4f} " + f"noisy={rec['noisy_loss_mean']:.4f} " + f"sigma={rec['noise_std_mean']:.2e}/{rec['noise_std_max']:.2e} " + f"effB={rec['effective_batch']}", + flush=True, + ) + + if (train_step + 1) % args.eval_every == 0: + acc, tok_acc = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) + print(f" >> EVAL @ step {train_step + 1}: exact_acc={acc:.4f} delta={acc - acc0:+.4f}", flush=True) + log["evals"].append({"kind": "task", "train_step": train_step + 1, "acc": acc, "tok_acc": tok_acc}) + if args.save_every_eval: + ckpt_path = save_checkpoint(head, save_dir, f"step_{train_step + 1}.pt") + log["checkpoints"].append({"kind": "eval", "train_step": train_step + 1, "acc": acc, "path": ckpt_path}) + if args.save_best and acc >= best_acc: + best_acc = acc + best_step = train_step + 1 + ckpt_path = save_checkpoint(head, save_dir, "best.pt") + log["best_acc"] = best_acc + log["best_step"] = best_step + log["best_checkpoint"] = ckpt_path + log["checkpoints"].append({"kind": "best", "train_step": train_step + 1, "acc": acc, "path": ckpt_path}) + if args.save_train_state: + state_path = save_training_state(head, optim, generator, args, save_dir, "best_state.pt", train_step + 1, best_acc, best_step) + log["best_state_checkpoint"] = state_path + log["checkpoints"].append({"kind": "best_state", "train_step": train_step + 1, "acc": acc, "path": state_path}) + if args.save_train_state: + state_path = save_training_state(head, optim, generator, args, save_dir, "latest_state.pt", train_step + 1, best_acc, best_step) + log["latest_state_checkpoint"] = state_path + log["checkpoints"].append({"kind": "latest_state", "train_step": train_step + 1, "acc": acc, "path": state_path}) + write_log(args.out, log) + + acc_f, tok_f = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) + print("\n=== Final eval ===") + print(f" initial={acc0:.4f} final={acc_f:.4f} delta={acc_f - acc0:+.4f}", flush=True) + log["final_acc"] = acc_f + log["final_tok_acc"] = tok_f + log["evals"].append({"kind": "final", "train_step": args.train_steps, "acc": acc_f, "tok_acc": tok_f}) + if args.save_final: + ckpt_path = save_checkpoint(head, save_dir, "final.pt") + log["final_checkpoint"] = ckpt_path + log["checkpoints"].append({"kind": "final", "train_step": args.train_steps, "acc": acc_f, "path": ckpt_path}) + if args.save_train_state: + state_path = save_training_state(head, optim, generator, args, save_dir, "final_state.pt", args.train_steps, best_acc, best_step) + log["final_state_checkpoint"] = state_path + log["checkpoints"].append({"kind": "final_state", "train_step": args.train_steps, "acc": acc_f, "path": state_path}) + write_log(args.out, log) + print(f"log -> {args.out}") + + +if __name__ == "__main__": + main() |
