summaryrefslogtreecommitdiff
path: root/research/flossing/step7_interfloss.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/step7_interfloss.py')
-rw-r--r--research/flossing/step7_interfloss.py589
1 files changed, 589 insertions, 0 deletions
diff --git a/research/flossing/step7_interfloss.py b/research/flossing/step7_interfloss.py
new file mode 100644
index 0000000..3b8e4f0
--- /dev/null
+++ b/research/flossing/step7_interfloss.py
@@ -0,0 +1,589 @@
+"""Step 7: Engelken-style interflossing.
+
+This is intentionally not a mixed objective. Ordinary task-training steps use
+only the supervised ACT loss. Flossing episodes use only a Lyapunov-spectrum
+conditioning loss, then task training resumes.
+
+Paper mapping:
+ - preflossing: run a floss-only episode before task training.
+ - interflossing: run short floss-only episodes at selected training steps.
+ - no persistent L_task + alpha * L_floss term is used here.
+"""
+from __future__ import annotations
+
+import argparse
+import importlib
+import json
+import sys
+import time
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import yaml
+
+
+HRM_DIR = Path("/home/yurenh2/rrm/hrm")
+TRM_DIR = Path("/home/yurenh2/rrm/trm")
+
+
+def import_stack(model_type: str):
+ repo_dir = HRM_DIR if model_type == "hrm" else TRM_DIR
+ sys.path.insert(0, str(repo_dir))
+ if model_type == "hrm":
+ model_mod = importlib.import_module("models.hrm.hrm_act_v1")
+ model_cls = model_mod.HierarchicalReasoningModel_ACTV1
+ else:
+ model_mod = importlib.import_module("models.recursive_reasoning.trm")
+ model_cls = model_mod.TinyRecursiveReasoningModel_ACTV1
+ losses_mod = importlib.import_module("models.losses")
+ optim_mod = importlib.import_module("adam_atan2")
+ sparse_mod = importlib.import_module("models.sparse_embedding")
+ return model_cls, losses_mod.ACTLossHead, optim_mod.AdamATan2, sparse_mod.CastedSparseEmbeddingSignSGD_Distributed
+
+
+def parse_step_list(text: str) -> set[int]:
+ if not text.strip():
+ return set()
+ out = set()
+ for part in text.split(","):
+ part = part.strip()
+ if not part:
+ continue
+ out.add(int(part))
+ return out
+
+
+def build_interfloss_steps(args) -> set[int]:
+ steps = parse_step_list(args.interfloss_at)
+ if args.interfloss_every and args.interfloss_every > 0:
+ start = max(args.interfloss_start, 0)
+ stop = args.interfloss_stop if args.interfloss_stop >= 0 else args.train_steps
+ stop = min(stop, args.train_steps)
+ steps.update(range(start, stop + 1, args.interfloss_every))
+ return steps
+
+
+def load_model(model_type: str, ckpt_root: Path, ckpt_name: str, device: str, batch_size_override: int | None = None):
+ model_cls, loss_head_cls, adam_cls, sparse_cls = import_stack(model_type)
+ cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text())
+ arch_cfg = dict(cfg["arch"])
+ data_path = Path(cfg.get("data_path") or cfg["data_paths"][0])
+ train_meta = json.loads((data_path / "train" / "dataset.json").read_text())
+ arch_cfg.update(
+ batch_size=batch_size_override or cfg["global_batch_size"],
+ seq_len=train_meta["seq_len"],
+ vocab_size=train_meta["vocab_size"],
+ num_puzzle_identifiers=train_meta["num_puzzle_identifiers"],
+ causal=False,
+ )
+ cfg["data_path"] = str(data_path)
+ with torch.device(device):
+ base = model_cls(arch_cfg)
+ head = loss_head_cls(base, loss_type=arch_cfg["loss"]["loss_type"])
+ if ckpt_name != "__random__":
+ sd = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True)
+ stripped = {k.replace("_orig_mod.", ""): v for k, v in sd.items()}
+ missing, unexpected = head.load_state_dict(stripped, strict=False)
+ print(f"[load {ckpt_name}] missing={len(missing)} unexpected={len(unexpected)}")
+ else:
+ print("[load __random__] random initialization from config")
+ return head, base, cfg, adam_cls, sparse_cls
+
+
+def jvp_train(f, x, v):
+ return torch.autograd.functional.jvp(f, x, v=v, create_graph=True, strict=False)
+
+
+def compute_joint_lyap_spec(model_type, base, batch, k_lyap, lyap_act_steps, device, seed, lyap_start_act=0):
+ inner = base.inner
+ cfg = inner.config
+ bsz = batch["inputs"].shape[0]
+ seq_full = cfg.seq_len + inner.puzzle_emb_len
+ hidden = cfg.hidden_size
+ dim = seq_full * hidden
+
+ z_h = inner.H_init.unsqueeze(0).expand(bsz, seq_full, hidden).clone().to(inner.forward_dtype)
+ z_l = inner.L_init.unsqueeze(0).expand(bsz, seq_full, hidden).clone().to(inner.forward_dtype)
+ seq_info = {"cos_sin": inner.rotary_emb() if hasattr(inner, "rotary_emb") else None}
+ input_embeddings = inner._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
+
+ # Optional late-window measurement: first move to a later recursive state
+ # without differentiating through the warmup trajectory. This regularizes
+ # local late-stage stability instead of penalizing useful early expansion.
+ warmup_acts = min(max(lyap_start_act, 0), cfg.halt_max_steps)
+ if warmup_acts > 0:
+ with torch.no_grad():
+ for _act in range(warmup_acts):
+ for _h in range(cfg.H_cycles):
+ for _l in range(cfg.L_cycles):
+ z_l = inner.L_level(z_l, z_h + input_embeddings, **seq_info)
+ if model_type == "trm":
+ z_h = inner.L_level(z_h, z_l, **seq_info)
+ else:
+ z_h = inner.H_level(z_h, z_l, **seq_info)
+ z_h = z_h.detach()
+ z_l = z_l.detach()
+
+ gen = torch.Generator(device=device).manual_seed(seed)
+ q0 = torch.randn(bsz, 2 * dim, k_lyap, device=device, dtype=torch.float32, generator=gen)
+ q, _ = torch.linalg.qr(q0)
+ log_r_sum = torch.zeros(bsz, k_lyap, device=device, dtype=torch.float32)
+ n_steps = 0
+
+ n_act = min(lyap_act_steps, max(cfg.halt_max_steps - warmup_acts, 1))
+ for _act in range(n_act):
+ for _h in range(cfg.H_cycles):
+ for _l in range(cfg.L_cycles):
+ v_h = q[:, :dim, :]
+ v_l = q[:, dim:, :]
+ v_comb = v_h + v_l
+ new_v_l_cols = []
+ f_l = lambda z: inner.L_level(z, z_h + input_embeddings, **seq_info)
+ for i in range(k_lyap):
+ v_i = v_comb[:, :, i].reshape(bsz, seq_full, hidden).to(inner.forward_dtype)
+ z_l_new, d_v = jvp_train(f_l, z_l, v_i)
+ new_v_l_cols.append(d_v.reshape(bsz, dim).to(torch.float32))
+ q = torch.cat([v_h, torch.stack(new_v_l_cols, dim=-1)], dim=1)
+ z_l = z_l_new
+ q, r = torch.linalg.qr(q)
+ log_r_sum = log_r_sum + r.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log()
+ n_steps += 1
+
+ v_h = q[:, :dim, :]
+ v_l = q[:, dim:, :]
+ v_comb = v_h + v_l
+ new_v_h_cols = []
+ if model_type == "trm":
+ f_h = lambda z: inner.L_level(z, z_l, **seq_info)
+ else:
+ f_h = lambda z: inner.H_level(z, z_l, **seq_info)
+ for i in range(k_lyap):
+ v_i = v_comb[:, :, i].reshape(bsz, seq_full, hidden).to(inner.forward_dtype)
+ z_h_new, d_v = jvp_train(f_h, z_h, v_i)
+ new_v_h_cols.append(d_v.reshape(bsz, dim).to(torch.float32))
+ q = torch.cat([torch.stack(new_v_h_cols, dim=-1), v_l], dim=1)
+ z_h = z_h_new
+ q, r = torch.linalg.qr(q)
+ log_r_sum = log_r_sum + r.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log()
+ n_steps += 1
+
+ return log_r_sum / max(n_steps, 1)
+
+
+def floss_loss_from_spec(spec, mode: str, lambda_star: float):
+ if mode == "engelken_l2":
+ return (spec ** 2).mean(), spec
+ if mode == "spectrum_cf":
+ excess = (spec - lambda_star).clamp_min(0.0)
+ return (excess ** 2).mean(), excess
+ if mode == "volume_cf":
+ volume = spec.mean(dim=1)
+ excess = (volume - lambda_star).clamp_min(0.0)
+ return (excess ** 2).mean(), excess
+ if mode == "top1_cf":
+ excess = (spec[:, 0] - lambda_star).clamp_min(0.0)
+ return (excess ** 2).mean(), excess
+ raise ValueError(f"unknown floss mode: {mode}")
+
+
+def load_train_batches(data_path: Path, batch_size: int, n_iters: int, seed: int = 0):
+ rng = np.random.default_rng(seed)
+ inputs = np.load(data_path / "train" / "all__inputs.npy")
+ labels = np.load(data_path / "train" / "all__labels.npy")
+ pid = np.load(data_path / "train" / "all__puzzle_identifiers.npy")
+ n = len(inputs)
+ for _ in range(n_iters):
+ idx = rng.choice(n, size=batch_size, replace=False)
+ yield {
+ "inputs": torch.from_numpy(inputs[idx].astype(np.int32)),
+ "labels": torch.from_numpy(labels[idx].astype(np.int32)),
+ "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)),
+ }
+
+
+def sample_replay_batch(data_path: Path, n_samples: int, seed: int):
+ rng = np.random.default_rng(seed)
+ inputs = np.load(data_path / "train" / "all__inputs.npy")
+ labels = np.load(data_path / "train" / "all__labels.npy")
+ pid = np.load(data_path / "train" / "all__puzzle_identifiers.npy")
+ idx = rng.choice(len(inputs), size=n_samples, replace=False)
+ return {
+ "inputs": torch.from_numpy(inputs[idx].astype(np.int32)),
+ "labels": torch.from_numpy(labels[idx].astype(np.int32)),
+ "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)),
+ }
+
+
+def move_batch(batch: dict[str, torch.Tensor], device: str):
+ return {k: v.to(device) for k, v in batch.items()}
+
+
+def rollout_logits(base, batch, device):
+ with torch.device(device):
+ carry = base.initial_carry(batch)
+ for _ in range(base.config.halt_max_steps):
+ carry, outputs = base(carry=carry, batch=batch)
+ return outputs["logits"]
+
+
+def build_kl_replay(args, base, data_path, device, episode_idx):
+ if args.kl_beta <= 0 or args.kl_replay_size <= 0:
+ return None
+
+ replay = sample_replay_batch(
+ data_path,
+ n_samples=args.kl_replay_size,
+ seed=args.seed + 200000 + episode_idx,
+ )
+ teacher_chunks = []
+ base.eval()
+ with torch.no_grad():
+ for start in range(0, args.kl_replay_size, args.kl_batch_size):
+ end = min(start + args.kl_batch_size, args.kl_replay_size)
+ batch = move_batch({k: v[start:end] for k, v in replay.items()}, device)
+ logits = rollout_logits(base, batch, device)
+ teacher_chunks.append(logits.detach().to(torch.float32).cpu())
+
+ replay["teacher_logits"] = torch.cat(teacher_chunks, dim=0)
+ replay["mask"] = replay["labels"] > 0
+ return replay
+
+
+def kl_preservation_loss(args, base, replay, step, device):
+ if replay is None:
+ return torch.zeros((), device=device)
+
+ n_replay = replay["inputs"].shape[0]
+ batch_size = min(args.kl_batch_size, n_replay)
+ start = (step * batch_size) % n_replay
+ if start + batch_size <= n_replay:
+ idx = torch.arange(start, start + batch_size)
+ else:
+ idx = torch.cat([torch.arange(start, n_replay), torch.arange(0, start + batch_size - n_replay)])
+
+ batch = move_batch(
+ {
+ "inputs": replay["inputs"][idx],
+ "labels": replay["labels"][idx],
+ "puzzle_identifiers": replay["puzzle_identifiers"][idx],
+ },
+ device,
+ )
+ teacher_logits = replay["teacher_logits"][idx].to(device)
+ mask = replay["mask"][idx].to(device)
+ was_training = base.training
+ base.eval()
+ student_logits = rollout_logits(base, batch, device).to(torch.float32)
+ if was_training:
+ base.train()
+ set_puzzle_embedding_mode(base, args.train_puzzle_emb)
+ temp = args.kl_temperature
+ student_logp = F.log_softmax(student_logits / temp, dim=-1)
+ teacher_p = F.softmax(teacher_logits / temp, dim=-1)
+ kl_per_token = F.kl_div(student_logp, teacher_p, reduction="none").sum(dim=-1) * (temp ** 2)
+ if mask.any():
+ return kl_per_token[mask].mean()
+ return kl_per_token.mean()
+
+
+def evaluate(head, base, data_path, n_samples, batch_size, device, seed=42):
+ rng = np.random.default_rng(seed)
+ inputs = np.load(data_path / "test" / "all__inputs.npy")
+ labels = np.load(data_path / "test" / "all__labels.npy")
+ pid = np.load(data_path / "test" / "all__puzzle_identifiers.npy")
+ idx_all = rng.choice(len(inputs), size=n_samples, replace=False)
+ head.eval()
+ correct = 0
+ token_correct = 0
+ token_total = 0
+ for start in range(0, n_samples, batch_size):
+ end = min(start + batch_size, n_samples)
+ idx = idx_all[start:end]
+ batch = {
+ "inputs": torch.from_numpy(inputs[idx].astype(np.int32)).to(device),
+ "labels": torch.from_numpy(labels[idx].astype(np.int32)).to(device),
+ "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)).to(device),
+ }
+ with torch.no_grad():
+ with torch.device(device):
+ carry = base.initial_carry(batch)
+ for _ in range(base.config.halt_max_steps):
+ carry, outputs = base(carry=carry, batch=batch)
+ preds = outputs["logits"].argmax(dim=-1)
+ mask = batch["labels"] > 0
+ exact = ((preds == batch["labels"]) | ~mask).all(dim=-1).float()
+ correct += exact.sum().item()
+ token_correct += ((preds == batch["labels"]) & mask).sum().item()
+ token_total += mask.sum().item()
+ return correct / n_samples, token_correct / max(token_total, 1)
+
+
+def write_log(path: str, log: dict):
+ Path(path).write_text(json.dumps(log, indent=2))
+
+
+def freeze_puzzle_embedding(base):
+ base.inner.puzzle_emb.eval()
+
+
+def set_puzzle_embedding_mode(base, train_puzzle_emb: bool):
+ if train_puzzle_emb:
+ base.inner.puzzle_emb.train()
+ else:
+ freeze_puzzle_embedding(base)
+
+
+def make_optimizers(args, base, head, adam_cls, sparse_cls, lr: float, weight_decay: float, train_puzzle_emb: bool):
+ optimizers = []
+ if train_puzzle_emb and getattr(base.inner.config, "puzzle_emb_ndim", 0) > 0:
+ optimizers.append(
+ sparse_cls(
+ base.inner.puzzle_emb.buffers(),
+ lr=lr if args.puzzle_emb_lr is None else args.puzzle_emb_lr,
+ weight_decay=args.puzzle_emb_weight_decay,
+ world_size=1,
+ )
+ )
+ optimizers.append(adam_cls(head.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=weight_decay))
+ return optimizers
+
+
+def optim_zero_grad(optimizers):
+ for optim in optimizers:
+ optim.zero_grad(set_to_none=True)
+
+
+def optim_step(optimizers):
+ for optim in optimizers:
+ optim.step()
+
+
+def run_floss_episode(args, head, base, adam_cls, data_path, device, log, episode_idx, train_step):
+ print(
+ f"\n=== Floss episode {episode_idx} at train_step={train_step}: "
+ f"{args.floss_steps} steps, mode={args.floss_mode}, lr={args.floss_lr} ===",
+ flush=True,
+ )
+ optimizers = make_optimizers(
+ args, base, head, adam_cls, args.sparse_cls,
+ lr=args.floss_lr, weight_decay=0.0, train_puzzle_emb=False,
+ )
+ replay = build_kl_replay(args, base, data_path, device, episode_idx)
+ train_iter = load_train_batches(
+ data_path,
+ args.floss_batch_size,
+ args.floss_steps,
+ seed=args.seed + 100000 + episode_idx * 1000,
+ )
+ episode = {"episode": episode_idx, "train_step": train_step, "steps": []}
+ t0 = time.time()
+
+ for step, batch in enumerate(train_iter):
+ batch = {k: v.to(device) for k, v in batch.items()}
+ head.train()
+ set_puzzle_embedding_mode(base, False)
+ spec = compute_joint_lyap_spec(
+ args.model,
+ base,
+ batch,
+ k_lyap=args.k_lyap,
+ lyap_act_steps=args.lyap_act_steps,
+ device=device,
+ seed=args.seed + episode_idx * 10000 + step,
+ lyap_start_act=args.lyap_start_act,
+ )
+ optim_zero_grad(optimizers)
+ floss_loss, excess = floss_loss_from_spec(spec, args.floss_mode, args.lambda_star)
+ floss_loss.backward()
+ kl_loss = kl_preservation_loss(args, base, replay, step, device)
+ if args.kl_beta > 0:
+ (args.kl_beta * kl_loss).backward()
+ torch.nn.utils.clip_grad_norm_([p for p in head.parameters() if p.requires_grad], 1.0)
+ optim_step(optimizers)
+
+ detached = spec.detach()
+ total_loss = floss_loss.detach() + args.kl_beta * kl_loss.detach()
+ rec = {
+ "step": step,
+ "loss": float(total_loss.item()),
+ "floss_loss": float(floss_loss.item()),
+ "kl_loss": float(kl_loss.item()),
+ "lyap1_mean": float(detached[:, 0].mean().item()),
+ "lyap1_max": float(detached[:, 0].max().item()),
+ "lyap_mean": float(detached.mean().item()),
+ "volume_mean": float(detached.mean(dim=1).mean().item()),
+ "volume_max": float(detached.mean(dim=1).max().item()),
+ "frac_active": float((excess.detach() > 0).float().mean().item()),
+ }
+ episode["steps"].append(rec)
+ if step % args.floss_log_every == 0 or step == args.floss_steps - 1:
+ print(
+ f" F[{step:>4}/{args.floss_steps}] dt={time.time() - t0:.1f}s "
+ f"loss={rec['loss']:.6f} floss={rec['floss_loss']:.6f} "
+ f"kl={rec['kl_loss']:.6f} lyap1={rec['lyap1_mean']:+.4f} "
+ f"vol={rec['volume_mean']:+.4f} active={rec['frac_active']:.2f}",
+ flush=True,
+ )
+
+ log["floss_episodes"].append(episode)
+ if args.eval_after_floss:
+ acc, tok_acc = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device)
+ print(f" >> FLOSS EVAL train_step={train_step}: exact_acc={acc:.4f}", flush=True)
+ log["evals"].append(
+ {"kind": "after_floss", "train_step": train_step, "episode": episode_idx, "acc": acc, "tok_acc": tok_acc}
+ )
+ write_log(args.out, log)
+
+
+def run_task_step(args, head, base, batch, optimizers, device):
+ batch = {k: v.to(device) for k, v in batch.items()}
+ head.train()
+ set_puzzle_embedding_mode(base, args.train_puzzle_emb)
+ with torch.device(device):
+ carry = base.initial_carry(batch)
+ loss_sum = 0.0
+ n_loss = 0
+ for _ in range(base.config.halt_max_steps):
+ carry, loss, _metrics, _outputs, all_finish = head(return_keys=[], carry=carry, batch=batch)
+ loss_sum = loss_sum + loss
+ n_loss += 1
+ if all_finish:
+ break
+ sup_loss = loss_sum / max(n_loss, 1) / batch["inputs"].shape[0]
+ optim_zero_grad(optimizers)
+ sup_loss.backward()
+ torch.nn.utils.clip_grad_norm_([p for p in head.parameters() if p.requires_grad], 1.0)
+ optim_step(optimizers)
+ return sup_loss
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", choices=["hrm", "trm"], required=True)
+ parser.add_argument("--ckpt-root", required=True)
+ parser.add_argument("--ckpt-name", required=True,
+ help="Checkpoint file name, or __random__ to initialize from config without loading weights.")
+ parser.add_argument("--train-steps", type=int, default=10000)
+ parser.add_argument("--batch-size", type=int, default=8)
+ parser.add_argument("--task-batch-size", type=int, default=None,
+ help="Supervised task microbatch size. Defaults to --batch-size.")
+ parser.add_argument("--floss-batch-size", type=int, default=None,
+ help="Flossing microbatch size. Defaults to --batch-size.")
+ parser.add_argument("--train-lr", type=float, default=1e-5)
+ parser.add_argument("--floss-lr", type=float, default=1e-4)
+ parser.add_argument("--floss-steps", type=int, default=500)
+ parser.add_argument("--interfloss-at", default="0,500")
+ parser.add_argument("--interfloss-every", type=int, default=0,
+ help="If >0, also run floss episodes periodically every N task optimizer steps.")
+ parser.add_argument("--interfloss-start", type=int, default=0,
+ help="First task optimizer step for periodic interfloss.")
+ parser.add_argument("--interfloss-stop", type=int, default=-1,
+ help="Last task optimizer step for periodic interfloss. -1 means train_steps.")
+ parser.add_argument("--floss-mode", choices=["engelken_l2", "spectrum_cf", "volume_cf", "top1_cf"], default="engelken_l2")
+ parser.add_argument("--lambda-star", type=float, default=0.0)
+ parser.add_argument("--k-lyap", type=int, default=8)
+ parser.add_argument("--lyap-act-steps", type=int, default=4)
+ parser.add_argument("--lyap-start-act", type=int, default=0,
+ help="Warm up this many ACT steps before measuring/flossing the Lyapunov window.")
+ parser.add_argument("--seed", type=int, default=42)
+ parser.add_argument("--eval-every", type=int, default=1000)
+ parser.add_argument("--eval-n", type=int, default=512)
+ parser.add_argument("--eval-batch-size", type=int, default=32)
+ parser.add_argument("--floss-log-every", type=int, default=10)
+ parser.add_argument("--eval-after-floss", action=argparse.BooleanOptionalAction, default=True)
+ parser.add_argument("--kl-beta", type=float, default=0.0,
+ help="Episode-start replay-logit KL weight during floss-only steps.")
+ parser.add_argument("--kl-replay-size", type=int, default=64)
+ parser.add_argument("--kl-batch-size", type=int, default=8)
+ parser.add_argument("--kl-temperature", type=float, default=1.0)
+ parser.add_argument("--init-seed", type=int, default=None,
+ help="Torch seed used before model construction. Use this for matched from-scratch runs.")
+ parser.add_argument("--train-puzzle-emb", action=argparse.BooleanOptionalAction, default=False,
+ help="Train sparse puzzle embeddings. Requires --batch-size to match the model local embedding batch.")
+ parser.add_argument("--puzzle-emb-lr", type=float, default=None,
+ help="Sparse puzzle embedding LR. Defaults to current phase LR.")
+ parser.add_argument("--puzzle-emb-weight-decay", type=float, default=1.0)
+ parser.add_argument("--out", default="step7_interfloss_log.json")
+ args = parser.parse_args()
+ if args.task_batch_size is None:
+ args.task_batch_size = args.batch_size
+ if args.floss_batch_size is None:
+ args.floss_batch_size = args.batch_size
+ args.batch_size = args.task_batch_size
+
+ device = "cuda"
+ if args.init_seed is not None:
+ torch.manual_seed(args.init_seed)
+ np.random.seed(args.init_seed)
+ interfloss_steps = build_interfloss_steps(args)
+ head, base, cfg, adam_cls, sparse_cls = load_model(
+ args.model,
+ Path(args.ckpt_root),
+ args.ckpt_name,
+ device,
+ batch_size_override=args.task_batch_size if args.train_puzzle_emb else None,
+ )
+ args.sparse_cls = sparse_cls
+ data_path = Path(cfg["data_path"])
+
+ print(f"\n=== Initial eval (loaded {args.ckpt_name}) ===")
+ acc0, tok0 = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device)
+ print(f" initial: exact_acc={acc0:.4f} token_acc={tok0:.4f}", flush=True)
+
+ log = {
+ "args": {k: v for k, v in vars(args).items() if k != "sparse_cls"},
+ "initial_acc": acc0,
+ "initial_tok_acc": tok0,
+ "interfloss_steps": sorted(interfloss_steps),
+ "task_steps": [],
+ "floss_episodes": [],
+ "evals": [{"kind": "initial", "train_step": 0, "acc": acc0, "tok_acc": tok0}],
+ }
+ write_log(args.out, log)
+
+ task_optimizers = make_optimizers(
+ args, base, head, adam_cls, sparse_cls,
+ lr=args.train_lr, weight_decay=cfg["weight_decay"], train_puzzle_emb=args.train_puzzle_emb,
+ )
+ train_iter = load_train_batches(data_path, args.task_batch_size, args.train_steps, seed=args.seed)
+ episode_idx = 0
+ t0 = time.time()
+
+ for train_step, batch in enumerate(train_iter):
+ if train_step in interfloss_steps:
+ run_floss_episode(args, head, base, adam_cls, data_path, device, log, episode_idx, train_step)
+ episode_idx += 1
+
+ sup_loss = run_task_step(args, head, base, batch, task_optimizers, device)
+ rec = {"train_step": train_step + 1, "sup_loss": float(sup_loss.item())}
+ log["task_steps"].append(rec)
+ if train_step % 50 == 0 or train_step == args.train_steps - 1:
+ print(
+ f" T[{train_step + 1:>5}/{args.train_steps}] dt={time.time() - t0:.1f}s "
+ f"sup={rec['sup_loss']:.4f}",
+ flush=True,
+ )
+
+ if (train_step + 1) % args.eval_every == 0:
+ acc, tok_acc = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device)
+ print(f" >> TASK EVAL @ step {train_step + 1}: exact_acc={acc:.4f} delta={acc - acc0:+.4f}", flush=True)
+ log["evals"].append({"kind": "task", "train_step": train_step + 1, "acc": acc, "tok_acc": tok_acc})
+ write_log(args.out, log)
+
+ if args.train_steps in interfloss_steps:
+ run_floss_episode(args, head, base, adam_cls, data_path, device, log, episode_idx, args.train_steps)
+
+ acc_f, tok_f = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device)
+ print("\n=== Final eval ===")
+ print(f" initial={acc0:.4f} final={acc_f:.4f} delta={acc_f - acc0:+.4f}", flush=True)
+ log["final_acc"] = acc_f
+ log["final_tok_acc"] = tok_f
+ log["evals"].append({"kind": "final", "train_step": args.train_steps, "acc": acc_f, "tok_acc": tok_f})
+ write_log(args.out, log)
+ print(f"log -> {args.out}")
+
+
+if __name__ == "__main__":
+ main()