diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
| commit | 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch) | |
| tree | c29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/maze_package | |
Curated export for clone-and-run Maze training (2x A6000) + diagnostics.
trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible).
Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/maze_package')
9 files changed, 981 insertions, 0 deletions
diff --git a/research/flossing/maze_package/README.md b/research/flossing/maze_package/README.md new file mode 100644 index 0000000..45ab740 --- /dev/null +++ b/research/flossing/maze_package/README.md @@ -0,0 +1,30 @@ +# Maze-Hard package (E8) — train on dedicated cards, diagnose after + +## Contents +- `launch_maze_trm.sh` — TRM Maze official recipe (att variant, 50k epochs), 1–2 GPU. +- dataset already at `/home/yurenh2/rrm/data/maze-30x30-hard-1k` (built 2026-06-13; + seq_len 900, vocab 6, 1000 puzzles ×8 dihedral augments). + +## Run +```bash +bash launch_maze_trm.sh 2 384 # 2x A6000 +bash launch_maze_trm.sh 2 192 # 2x A5000 (->128 if OOM) +``` +Target: ~75% exact accuracy (official figure). Saves a checkpoint every 5000 epochs +(10 checkpoints) — needed for the evolution analysis. + +## After training: diagnostics +The 2x2 / FTLE pipeline reads any TRM checkpoint dir (all_config.yaml + step_N). Two caveats +vs Sudoku, to verify on first run: +1. ATTENTION arch (not mlp_t): confirm diagnose_trm_joint.py's JVP path runs on att blocks + (Sudoku used mlp_t). If the L_level call signature differs, patch the f_L/f_H closures. +2. seq_len 900 vs 97 → per-sample JVP+QR cost ~9-10x Sudoku. Use n=512 for the headline 2x2 + and n=256 for the horizon sweep; k_lyap=8 unchanged. Budget ~0.5-1 day on one card, or + rsync checkpoints back to the lab box and run via the analysis_2x2 queue. + +## What Maze closes +Kills the "Sudoku-only" limitation. Pre-registered prediction (write BEFORE looking, for the +paper's credibility): if the wandering-not-settling decomposition is architecture/task-general, +Maze should show B≈0 (failures don't settle) and the same concurrent-not-antecedent horizon +profile. A DIFFERENT result (e.g. Maze failures do settle) is also publishable — it bounds the +claim's scope. Either way the decomposition gets a second task. diff --git a/research/flossing/maze_package/TRANSFER_README.md b/research/flossing/maze_package/TRANSFER_README.md new file mode 100644 index 0000000..6e11076 --- /dev/null +++ b/research/flossing/maze_package/TRANSFER_README.md @@ -0,0 +1,36 @@ +# Maze training bundle — transfer to your training machine + +## What's in this bundle +- `maze-30x30-hard-1k/` — the built dataset (seq_len 900, vocab 6, 1000 puzzles ×8 augments). +- `launch_maze_trm_portable.sh` — path-configurable launcher. +- `diagnose_trm_joint.py`, `step7_interfloss.py` — diagnostic scripts (only if you run + diagnostics on the training machine; otherwise rsync checkpoints back to the lab box). + +## On the training machine +1. Have the TinyRecursiveModels repo cloned and the `rrm` conda env (torch 2.7 cu126, + flash-attn 2 for Ampere). If the env doesn't exist, recreate from the lab box's + `env/requirements.txt` / `pip-freeze.txt`. +2. Put the dataset somewhere, e.g. `~/data/maze-30x30-hard-1k`. +3. Launch: + ```bash + TRM_DIR=~/TinyRecursiveModels DATA_DIR=~/data/maze-30x30-hard-1k \ + bash launch_maze_trm_portable.sh 2 384 # 2x A6000 + # or: 2 192 # 2x A5000 (->128 if OOM) + ``` + Target ~75% exact accuracy (official). ~18-28h on 2x A6000, ~24-36h on 2x A5000. + Saves one checkpoint per 5000 epochs (10 total) — keep all, the evolution analysis needs them. + +## After training +Preferred: `rsync` the whole run checkpoint dir (checkpoints/maze-.../pretrain_att_maze30x30_*/) +back to the lab box and run the existing analysis_2x2 queue there. The dir must include +`all_config.yaml` plus the `step_*` files. + +If diagnosing on the training machine, two caveats vs the Sudoku runs: +1. Maze uses the ATTENTION arch (not mlp_t). Verify diagnose_trm_joint.py's f_L/f_H JVP + closures call the attention L_level correctly; patch if the signature differs. +2. seq_len 900 (vs 97) makes per-sample JVP+QR ~9-10x slower. Use n=512 for the headline 2x2, + n=256 for the horizon sweep, k_lyap=8. + +## Sanity check before the long run +A 200-step smoke (epochs=200 eval_interval=200) should complete in minutes and confirm the +attention model + flash-attn + dataset load without OOM before committing to 50k epochs. diff --git a/research/flossing/maze_package/diagnose_trm_joint.py b/research/flossing/maze_package/diagnose_trm_joint.py new file mode 100644 index 0000000..160a7cb --- /dev/null +++ b/research/flossing/maze_package/diagnose_trm_joint.py @@ -0,0 +1,225 @@ +"""TRM Sudoku joint Lyapunov diagnostic — TRM version of diagnose_hrm_joint.py. + +Key differences from HRM: +- TRM has ONE shared L_level (H_layers config is "ignored") +- z_L update: z_L = L_level(z_L, z_H + input_embeddings) +- z_H update: z_H = L_level(z_H, z_L) ← same L_level! +- H_cycles=3, L_cycles=6 (vs HRM 2,2) + +Joint tangent block structure: +- L step: v_L_new = J · (v_L + v_H), v_H_new = v_H, J at (z_L + z_H + ie) +- H step: v_H_new = J' · (v_H + v_L), v_L_new = v_L, J' at (z_H + z_L) +J and J' share weights but evaluated at different points. +""" +from __future__ import annotations +import sys, os, yaml, math, argparse, json, time +from pathlib import Path +import numpy as np +import torch + +TRM_DIR = Path("/home/yurenh2/rrm/trm") +sys.path.insert(0, str(TRM_DIR)) + +from models.recursive_reasoning.trm import TinyRecursiveReasoningModel_ACTV1 + + +def load_model(ckpt_root: Path, ckpt_name: str, device: str): + cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text()) + arch_cfg = dict(cfg["arch"]) + train_meta = json.loads((Path(cfg["data_paths"][0]) / "train" / "dataset.json").read_text()) + arch_cfg.update(batch_size=cfg["global_batch_size"], seq_len=train_meta["seq_len"], + vocab_size=train_meta["vocab_size"], + num_puzzle_identifiers=train_meta["num_puzzle_identifiers"]) + model = TinyRecursiveReasoningModel_ACTV1(arch_cfg) + sd = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True) + stripped = {k.replace("_orig_mod.", "").replace("model.", ""): v for k, v in sd.items()} + missing, unexpected = model.load_state_dict(stripped, strict=False) + print(f"[load] missing={len(missing)} unexpected={len(unexpected)}") + if missing[:3]: print(f" sample missing: {missing[:3]}") + if unexpected[:3]: print(f" sample unexpected: {unexpected[:3]}") + model.to(device).eval() + return model, cfg, train_meta + + +def load_test_samples(data_path, n_total, shard_id, num_shards, seed): + rng = np.random.default_rng(seed) + inputs = np.load(data_path / "test" / "all__inputs.npy") + labels = np.load(data_path / "test" / "all__labels.npy") + pid = np.load(data_path / "test" / "all__puzzle_identifiers.npy") + all_idx = rng.choice(len(inputs), size=n_total, replace=False) + shard_size = (n_total + num_shards - 1) // num_shards + s, e = shard_id * shard_size, min((shard_id + 1) * shard_size, n_total) + idx = all_idx[s:e] + return { + "inputs": torch.from_numpy(inputs[idx].astype(np.int32)), + "labels": torch.from_numpy(labels[idx].astype(np.int32)), + "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)), + "idx": idx, + } + + +def jvp_through(f, x, v): + return torch.autograd.functional.jvp(f, x, v=v, create_graph=False, strict=False) + + +def run_diagnose_batch(model, batch, device, k_lyap, t_ons, seed): + inner = model.inner + cfg = inner.config + B = batch["inputs"].shape[0] + seq_full = cfg.seq_len + inner.puzzle_emb_len + hidden = cfg.hidden_size + D = seq_full * hidden + + z_H = inner.H_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) + z_L = inner.L_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype) + seq_info = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None) + input_embeddings = inner._input_embeddings(batch["inputs"].to(device), + batch["puzzle_identifiers"].to(device)) + + g = torch.Generator(device=device).manual_seed(seed) + Q0 = torch.randn(B, 2*D, k_lyap, device=device, dtype=torch.float32, generator=g) + Q, _ = torch.linalg.qr(Q0) + log_R_sum = torch.zeros(B, k_lyap, device=device, dtype=torch.float32) + n_lyap_steps = 0 + step_counter = 0 + + drift_zH_per_step, drift_zL_per_step = [], [] + halted_at = torch.zeros(B, dtype=torch.long, device=device) + q_halt_hist, q_continue_hist = [], [] + + for act_step in range(cfg.halt_max_steps): + z_H_prev = z_H.detach().clone() + z_L_prev = z_L.detach().clone() + + with torch.enable_grad(): + zH, zL = z_H.detach(), z_L.detach() + for _h in range(cfg.H_cycles): + # L cycles + for _l in range(cfg.L_cycles): + v_H_j = Q[:, :D, :] + v_L_j = Q[:, D:, :] + v_comb = v_H_j + v_L_j + new_v_L_cols = [] + f_L = lambda z: inner.L_level(z, zH + input_embeddings, **seq_info) + for i in range(k_lyap): + v_i = v_comb[:, :, i].reshape(B, seq_full, hidden).to(inner.forward_dtype) + zL_new, Dv = jvp_through(f_L, zL, v_i) + new_v_L_cols.append(Dv.reshape(B, D).to(torch.float32)) + new_v_L = torch.stack(new_v_L_cols, dim=-1) + Q = torch.cat([v_H_j, new_v_L], dim=1) + zL = zL_new + step_counter += 1 + if step_counter % t_ons == 0: + Q, R = torch.linalg.qr(Q) + log_R_sum = log_R_sum + R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log() + n_lyap_steps += 1 + + # H step (uses SAME L_level!) + v_H_j = Q[:, :D, :] + v_L_j = Q[:, D:, :] + v_comb = v_H_j + v_L_j + new_v_H_cols = [] + f_H = lambda z: inner.L_level(z, zL, **seq_info) + for i in range(k_lyap): + v_i = v_comb[:, :, i].reshape(B, seq_full, hidden).to(inner.forward_dtype) + zH_new, Dv = jvp_through(f_H, zH, v_i) + new_v_H_cols.append(Dv.reshape(B, D).to(torch.float32)) + new_v_H = torch.stack(new_v_H_cols, dim=-1) + Q = torch.cat([new_v_H, v_L_j], dim=1) + zH = zH_new + step_counter += 1 + if step_counter % t_ons == 0: + Q, R = torch.linalg.qr(Q) + log_R_sum = log_R_sum + R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log() + n_lyap_steps += 1 + + z_H, z_L = zH, zL + + drift_zH_per_step.append((z_H - z_H_prev).float().flatten(1).norm(dim=1).cpu()) + drift_zL_per_step.append((z_L - z_L_prev).float().flatten(1).norm(dim=1).cpu()) + + with torch.no_grad(): + q_logits = inner.q_head(z_H[:, 0]).float() + q_halt, q_continue = q_logits[..., 0], q_logits[..., 1] + q_halt_hist.append(q_halt.cpu()); q_continue_hist.append(q_continue.cpu()) + new_halt = (q_halt > q_continue) & (halted_at == 0) + halted_at[new_halt] = act_step + 1 + output = inner.lm_head(z_H)[:, inner.puzzle_emb_len:].float() + final_logits = output + + lyap_spec = (log_R_sum / max(n_lyap_steps, 1)).cpu().numpy() + + with torch.no_grad(): + preds = final_logits.argmax(dim=-1) + labels = batch["labels"].to(device) + mask = labels > 0 + exact = ((preds == labels) | ~mask).all(dim=-1).cpu().float() + token_acc = ((preds == labels) & mask).sum(-1).float() / mask.sum(-1).float().clamp_min(1) + token_acc = token_acc.cpu() + + return { + "drift_zH": torch.stack(drift_zH_per_step, dim=1).numpy(), + "drift_zL": torch.stack(drift_zL_per_step, dim=1).numpy(), + "halted_at": halted_at.cpu().numpy(), + "q_halt": torch.stack(q_halt_hist, dim=1).numpy(), + "q_continue": torch.stack(q_continue_hist, dim=1).numpy(), + "lyap_spec": lyap_spec, + "exact_correct": exact.numpy(), + "token_acc": token_acc.numpy(), + } + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--ckpt-root", required=True) + ap.add_argument("--ckpt-name", default="step_13020") + ap.add_argument("--n-samples", type=int, default=512) + ap.add_argument("--shard-id", type=int, default=0) + ap.add_argument("--num-shards", type=int, default=1) + ap.add_argument("--batch-size", type=int, default=16) + ap.add_argument("--k-lyap", type=int, default=8) + ap.add_argument("--t-ons", type=int, default=1) + ap.add_argument("--seed", type=int, default=0) + ap.add_argument("--out", default="diag_trm.npz") + args = ap.parse_args() + + device = "cuda" + model, cfg, train_meta = load_model(Path(args.ckpt_root), args.ckpt_name, device) + print(f"loaded {args.ckpt_name}: hidden={model.inner.config.hidden_size}, " + f"seq_full={train_meta['seq_len'] + model.inner.puzzle_emb_len}, " + f"halt_max_steps={model.inner.config.halt_max_steps}, " + f"H={model.inner.config.H_cycles} L={model.inner.config.L_cycles}") + + test = load_test_samples(Path(cfg["data_paths"][0]), args.n_samples, args.shard_id, args.num_shards, args.seed) + n = len(test["inputs"]) + print(f"shard {args.shard_id}/{args.num_shards}: {n} samples") + + res = {k: [] for k in ["drift_zH","drift_zL","halted_at","q_halt","q_continue","lyap_spec","exact_correct","token_acc","idx"]} + t0 = time.time() + for s in range(0, n, args.batch_size): + e = min(s + args.batch_size, n) + batch = {k: test[k][s:e].to(device) for k in ["inputs","labels","puzzle_identifiers"]} + out = run_diagnose_batch(model, batch, device, args.k_lyap, args.t_ons, args.seed + s) + for k, v in out.items(): + res[k].append(v) + res["idx"].append(test["idx"][s:e]) + ls = out["lyap_spec"] + print(f" [{e}/{n}] dt={time.time()-t0:.1f}s exact={out['exact_correct'].mean():.3f} " + f"λ_1={ls[:,0].mean():+.4f} λ_{args.k_lyap}={ls[:,-1].mean():+.4f}", flush=True) + + saved = {} + for k, v in res.items(): + if not v: continue + try: saved[k] = np.concatenate(v, 0) + except ValueError: saved[k] = np.stack(v, 0) + np.savez_compressed(args.out, **saved) + succ = saved["exact_correct"] > 0.5 + print(f"\nN={len(succ)} acc={succ.mean():.4f}") + print(f"{'i':>3} {'all':>10} {'succ':>10} {'fail':>10} {'Δ':>9}") + for i in range(saved["lyap_spec"].shape[1]): + li = saved["lyap_spec"][:, i] + print(f"{i+1:>3} {li.mean():+10.4f} {li[succ].mean():+10.4f} {li[~succ].mean():+10.4f} {li[~succ].mean()-li[succ].mean():+9.4f}") + + +if __name__ == "__main__": + main() diff --git a/research/flossing/maze_package/launch_maze_trm.sh b/research/flossing/maze_package/launch_maze_trm.sh new file mode 100755 index 0000000..093bb1e --- /dev/null +++ b/research/flossing/maze_package/launch_maze_trm.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash +# TRM Maze-Hard 30x30 official recipe, adapted for 2 GPUs. Run on dedicated training cards. +# Usage: bash launch_maze_trm.sh [NGPU] [GBS] +# 2x A6000 (48G): bash launch_maze_trm.sh 2 384 +# 2x A5000 (24G): bash launch_maze_trm.sh 2 192 (drop to 128 if OOM) +# 1x card: bash launch_maze_trm.sh 1 128 +set -eo pipefail + +NGPU="${1:-2}" +GBS="${2:-384}" +RUN_NAME="pretrain_att_maze30x30_${NGPU}gpu_gbs${GBS}" + +source /home/yurenh2/miniconda3/etc/profile.d/conda.sh +conda activate rrm +cd /home/yurenh2/rrm/trm +export WANDB_MODE=offline + +COMMON_ARGS=( + arch=trm + "data_paths=[/home/yurenh2/rrm/data/maze-30x30-hard-1k]" + "evaluators=[]" + epochs=50000 eval_interval=5000 + lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 + global_batch_size="${GBS}" + arch.L_layers=2 arch.H_cycles=3 arch.L_cycles=4 + +run_name="${RUN_NAME}" ema=True + +checkpoint_every_eval=true +) + +LOG="/home/yurenh2/rrm/research/flossing/maze_${RUN_NAME}.log" + +if [[ "${NGPU}" -gt 1 ]]; then + nohup torchrun --nproc-per-node "${NGPU}" --rdzv_backend=c10d --rdzv_endpoint=localhost:0 \ + --nnodes=1 pretrain.py "${COMMON_ARGS[@]}" > "${LOG}" 2>&1 & +else + nohup python pretrain.py "${COMMON_ARGS[@]}" > "${LOG}" 2>&1 & +fi +echo "launched ${RUN_NAME} (pid $!), log: ${LOG}" +echo "checkpoints -> trm/checkpoints/maze-30x30-hard-1k.../${RUN_NAME}/ (one per 5000 epochs)" +echo "monitor: tail -f ${LOG} | grep -E 'accuracy|exact'" diff --git a/research/flossing/maze_package/launch_maze_trm_portable.sh b/research/flossing/maze_package/launch_maze_trm_portable.sh new file mode 100755 index 0000000..d801ceb --- /dev/null +++ b/research/flossing/maze_package/launch_maze_trm_portable.sh @@ -0,0 +1,58 @@ +#!/usr/bin/env bash +# Portable TRM Maze-Hard launcher — run on any machine with the TRM repo + rrm conda env. +# +# Set these two paths for the target machine (env vars or edit here): +# TRM_DIR = path to the TinyRecursiveModels repo clone (contains pretrain.py) +# DATA_DIR = path to the maze-30x30-hard-1k dataset (from this bundle) +# CONDA_SH = path to conda.sh (default tries common locations) +# +# Usage: TRM_DIR=~/TinyRecursiveModels DATA_DIR=~/maze-30x30-hard-1k bash launch_maze_trm_portable.sh [NGPU] [GBS] +# 2x A6000 (48G): ... bash launch_maze_trm_portable.sh 2 384 +# 2x A5000 (24G): ... bash launch_maze_trm_portable.sh 2 192 (-> 128 if OOM) +set -eo pipefail + +TRM_DIR="${TRM_DIR:?set TRM_DIR to the TinyRecursiveModels repo path}" +DATA_DIR="${DATA_DIR:?set DATA_DIR to the maze-30x30-hard-1k dataset path}" +NGPU="${1:-2}" +GBS="${2:-384}" +RUN_NAME="pretrain_att_maze30x30_${NGPU}gpu_gbs${GBS}" + +# conda +CONDA_SH="${CONDA_SH:-}" +if [[ -z "${CONDA_SH}" ]]; then + for p in "$HOME/miniconda3/etc/profile.d/conda.sh" "$HOME/anaconda3/etc/profile.d/conda.sh" \ + "/opt/conda/etc/profile.d/conda.sh"; do + [[ -f "$p" ]] && CONDA_SH="$p" && break + done +fi +[[ -f "${CONDA_SH}" ]] && source "${CONDA_SH}" && conda activate "${CONDA_ENV:-rrm}" + +cd "${TRM_DIR}" +export WANDB_MODE=offline + +ARGS=( + arch=trm + "data_paths=[${DATA_DIR}]" + "evaluators=[]" + epochs=50000 eval_interval=5000 + lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 + global_batch_size="${GBS}" + arch.L_layers=2 arch.H_cycles=3 arch.L_cycles=4 + +run_name="${RUN_NAME}" ema=True + +checkpoint_every_eval=true +) +LOG="maze_${RUN_NAME}.log" + +if [[ "${NGPU}" -gt 1 ]]; then + nohup torchrun --nproc-per-node "${NGPU}" --rdzv_backend=c10d --rdzv_endpoint=localhost:0 \ + --nnodes=1 pretrain.py "${ARGS[@]}" > "${LOG}" 2>&1 & +else + nohup python pretrain.py "${ARGS[@]}" > "${LOG}" 2>&1 & +fi +echo "launched ${RUN_NAME} (pid $!)" +echo "log: ${TRM_DIR}/${LOG}" +echo "ckpts: ${TRM_DIR}/checkpoints/maze-30x30-hard-1k.../${RUN_NAME}/ (1 per 5000 epochs)" +echo "watch: tail -f ${TRM_DIR}/${LOG} | grep -E 'exact|accuracy'" +echo +echo "When done: rsync the run's checkpoint dir back to the lab box for the diagnostic pipeline," +echo "or run diagnostics here (see TRANSFER_README.md, note the attention-arch + n=512 caveats)." diff --git a/research/flossing/maze_package/maze-30x30-hard-1k/identifiers.json b/research/flossing/maze_package/maze-30x30-hard-1k/identifiers.json new file mode 100644 index 0000000..16b2b6c --- /dev/null +++ b/research/flossing/maze_package/maze-30x30-hard-1k/identifiers.json @@ -0,0 +1 @@ +["<blank>"]
\ No newline at end of file diff --git a/research/flossing/maze_package/maze-30x30-hard-1k/test/dataset.json b/research/flossing/maze_package/maze-30x30-hard-1k/test/dataset.json new file mode 100644 index 0000000..e3314c3 --- /dev/null +++ b/research/flossing/maze_package/maze-30x30-hard-1k/test/dataset.json @@ -0,0 +1 @@ +{"pad_id": 0, "ignore_label_id": 0, "blank_identifier_id": 0, "vocab_size": 6, "seq_len": 900, "num_puzzle_identifiers": 1, "total_groups": 1000, "mean_puzzle_examples": 1.0, "total_puzzles": 1000, "sets": ["all"]}
\ No newline at end of file diff --git a/research/flossing/maze_package/maze-30x30-hard-1k/train/dataset.json b/research/flossing/maze_package/maze-30x30-hard-1k/train/dataset.json new file mode 100644 index 0000000..e3314c3 --- /dev/null +++ b/research/flossing/maze_package/maze-30x30-hard-1k/train/dataset.json @@ -0,0 +1 @@ +{"pad_id": 0, "ignore_label_id": 0, "blank_identifier_id": 0, "vocab_size": 6, "seq_len": 900, "num_puzzle_identifiers": 1, "total_groups": 1000, "mean_puzzle_examples": 1.0, "total_puzzles": 1000, "sets": ["all"]}
\ No newline at end of file diff --git a/research/flossing/maze_package/step7_interfloss.py b/research/flossing/maze_package/step7_interfloss.py new file mode 100644 index 0000000..3b8e4f0 --- /dev/null +++ b/research/flossing/maze_package/step7_interfloss.py @@ -0,0 +1,589 @@ +"""Step 7: Engelken-style interflossing. + +This is intentionally not a mixed objective. Ordinary task-training steps use +only the supervised ACT loss. Flossing episodes use only a Lyapunov-spectrum +conditioning loss, then task training resumes. + +Paper mapping: + - preflossing: run a floss-only episode before task training. + - interflossing: run short floss-only episodes at selected training steps. + - no persistent L_task + alpha * L_floss term is used here. +""" +from __future__ import annotations + +import argparse +import importlib +import json +import sys +import time +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +import yaml + + +HRM_DIR = Path("/home/yurenh2/rrm/hrm") +TRM_DIR = Path("/home/yurenh2/rrm/trm") + + +def import_stack(model_type: str): + repo_dir = HRM_DIR if model_type == "hrm" else TRM_DIR + sys.path.insert(0, str(repo_dir)) + if model_type == "hrm": + model_mod = importlib.import_module("models.hrm.hrm_act_v1") + model_cls = model_mod.HierarchicalReasoningModel_ACTV1 + else: + model_mod = importlib.import_module("models.recursive_reasoning.trm") + model_cls = model_mod.TinyRecursiveReasoningModel_ACTV1 + losses_mod = importlib.import_module("models.losses") + optim_mod = importlib.import_module("adam_atan2") + sparse_mod = importlib.import_module("models.sparse_embedding") + return model_cls, losses_mod.ACTLossHead, optim_mod.AdamATan2, sparse_mod.CastedSparseEmbeddingSignSGD_Distributed + + +def parse_step_list(text: str) -> set[int]: + if not text.strip(): + return set() + out = set() + for part in text.split(","): + part = part.strip() + if not part: + continue + out.add(int(part)) + return out + + +def build_interfloss_steps(args) -> set[int]: + steps = parse_step_list(args.interfloss_at) + if args.interfloss_every and args.interfloss_every > 0: + start = max(args.interfloss_start, 0) + stop = args.interfloss_stop if args.interfloss_stop >= 0 else args.train_steps + stop = min(stop, args.train_steps) + steps.update(range(start, stop + 1, args.interfloss_every)) + return steps + + +def load_model(model_type: str, ckpt_root: Path, ckpt_name: str, device: str, batch_size_override: int | None = None): + model_cls, loss_head_cls, adam_cls, sparse_cls = import_stack(model_type) + cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text()) + arch_cfg = dict(cfg["arch"]) + data_path = Path(cfg.get("data_path") or cfg["data_paths"][0]) + train_meta = json.loads((data_path / "train" / "dataset.json").read_text()) + arch_cfg.update( + batch_size=batch_size_override or cfg["global_batch_size"], + seq_len=train_meta["seq_len"], + vocab_size=train_meta["vocab_size"], + num_puzzle_identifiers=train_meta["num_puzzle_identifiers"], + causal=False, + ) + cfg["data_path"] = str(data_path) + with torch.device(device): + base = model_cls(arch_cfg) + head = loss_head_cls(base, loss_type=arch_cfg["loss"]["loss_type"]) + if ckpt_name != "__random__": + sd = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True) + stripped = {k.replace("_orig_mod.", ""): v for k, v in sd.items()} + missing, unexpected = head.load_state_dict(stripped, strict=False) + print(f"[load {ckpt_name}] missing={len(missing)} unexpected={len(unexpected)}") + else: + print("[load __random__] random initialization from config") + return head, base, cfg, adam_cls, sparse_cls + + +def jvp_train(f, x, v): + return torch.autograd.functional.jvp(f, x, v=v, create_graph=True, strict=False) + + +def compute_joint_lyap_spec(model_type, base, batch, k_lyap, lyap_act_steps, device, seed, lyap_start_act=0): + inner = base.inner + cfg = inner.config + bsz = batch["inputs"].shape[0] + seq_full = cfg.seq_len + inner.puzzle_emb_len + hidden = cfg.hidden_size + dim = seq_full * hidden + + z_h = inner.H_init.unsqueeze(0).expand(bsz, seq_full, hidden).clone().to(inner.forward_dtype) + z_l = inner.L_init.unsqueeze(0).expand(bsz, seq_full, hidden).clone().to(inner.forward_dtype) + seq_info = {"cos_sin": inner.rotary_emb() if hasattr(inner, "rotary_emb") else None} + input_embeddings = inner._input_embeddings(batch["inputs"], batch["puzzle_identifiers"]) + + # Optional late-window measurement: first move to a later recursive state + # without differentiating through the warmup trajectory. This regularizes + # local late-stage stability instead of penalizing useful early expansion. + warmup_acts = min(max(lyap_start_act, 0), cfg.halt_max_steps) + if warmup_acts > 0: + with torch.no_grad(): + for _act in range(warmup_acts): + for _h in range(cfg.H_cycles): + for _l in range(cfg.L_cycles): + z_l = inner.L_level(z_l, z_h + input_embeddings, **seq_info) + if model_type == "trm": + z_h = inner.L_level(z_h, z_l, **seq_info) + else: + z_h = inner.H_level(z_h, z_l, **seq_info) + z_h = z_h.detach() + z_l = z_l.detach() + + gen = torch.Generator(device=device).manual_seed(seed) + q0 = torch.randn(bsz, 2 * dim, k_lyap, device=device, dtype=torch.float32, generator=gen) + q, _ = torch.linalg.qr(q0) + log_r_sum = torch.zeros(bsz, k_lyap, device=device, dtype=torch.float32) + n_steps = 0 + + n_act = min(lyap_act_steps, max(cfg.halt_max_steps - warmup_acts, 1)) + for _act in range(n_act): + for _h in range(cfg.H_cycles): + for _l in range(cfg.L_cycles): + v_h = q[:, :dim, :] + v_l = q[:, dim:, :] + v_comb = v_h + v_l + new_v_l_cols = [] + f_l = lambda z: inner.L_level(z, z_h + input_embeddings, **seq_info) + for i in range(k_lyap): + v_i = v_comb[:, :, i].reshape(bsz, seq_full, hidden).to(inner.forward_dtype) + z_l_new, d_v = jvp_train(f_l, z_l, v_i) + new_v_l_cols.append(d_v.reshape(bsz, dim).to(torch.float32)) + q = torch.cat([v_h, torch.stack(new_v_l_cols, dim=-1)], dim=1) + z_l = z_l_new + q, r = torch.linalg.qr(q) + log_r_sum = log_r_sum + r.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log() + n_steps += 1 + + v_h = q[:, :dim, :] + v_l = q[:, dim:, :] + v_comb = v_h + v_l + new_v_h_cols = [] + if model_type == "trm": + f_h = lambda z: inner.L_level(z, z_l, **seq_info) + else: + f_h = lambda z: inner.H_level(z, z_l, **seq_info) + for i in range(k_lyap): + v_i = v_comb[:, :, i].reshape(bsz, seq_full, hidden).to(inner.forward_dtype) + z_h_new, d_v = jvp_train(f_h, z_h, v_i) + new_v_h_cols.append(d_v.reshape(bsz, dim).to(torch.float32)) + q = torch.cat([torch.stack(new_v_h_cols, dim=-1), v_l], dim=1) + z_h = z_h_new + q, r = torch.linalg.qr(q) + log_r_sum = log_r_sum + r.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log() + n_steps += 1 + + return log_r_sum / max(n_steps, 1) + + +def floss_loss_from_spec(spec, mode: str, lambda_star: float): + if mode == "engelken_l2": + return (spec ** 2).mean(), spec + if mode == "spectrum_cf": + excess = (spec - lambda_star).clamp_min(0.0) + return (excess ** 2).mean(), excess + if mode == "volume_cf": + volume = spec.mean(dim=1) + excess = (volume - lambda_star).clamp_min(0.0) + return (excess ** 2).mean(), excess + if mode == "top1_cf": + excess = (spec[:, 0] - lambda_star).clamp_min(0.0) + return (excess ** 2).mean(), excess + raise ValueError(f"unknown floss mode: {mode}") + + +def load_train_batches(data_path: Path, batch_size: int, n_iters: int, seed: int = 0): + rng = np.random.default_rng(seed) + inputs = np.load(data_path / "train" / "all__inputs.npy") + labels = np.load(data_path / "train" / "all__labels.npy") + pid = np.load(data_path / "train" / "all__puzzle_identifiers.npy") + n = len(inputs) + for _ in range(n_iters): + idx = rng.choice(n, size=batch_size, replace=False) + yield { + "inputs": torch.from_numpy(inputs[idx].astype(np.int32)), + "labels": torch.from_numpy(labels[idx].astype(np.int32)), + "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)), + } + + +def sample_replay_batch(data_path: Path, n_samples: int, seed: int): + rng = np.random.default_rng(seed) + inputs = np.load(data_path / "train" / "all__inputs.npy") + labels = np.load(data_path / "train" / "all__labels.npy") + pid = np.load(data_path / "train" / "all__puzzle_identifiers.npy") + idx = rng.choice(len(inputs), size=n_samples, replace=False) + return { + "inputs": torch.from_numpy(inputs[idx].astype(np.int32)), + "labels": torch.from_numpy(labels[idx].astype(np.int32)), + "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)), + } + + +def move_batch(batch: dict[str, torch.Tensor], device: str): + return {k: v.to(device) for k, v in batch.items()} + + +def rollout_logits(base, batch, device): + with torch.device(device): + carry = base.initial_carry(batch) + for _ in range(base.config.halt_max_steps): + carry, outputs = base(carry=carry, batch=batch) + return outputs["logits"] + + +def build_kl_replay(args, base, data_path, device, episode_idx): + if args.kl_beta <= 0 or args.kl_replay_size <= 0: + return None + + replay = sample_replay_batch( + data_path, + n_samples=args.kl_replay_size, + seed=args.seed + 200000 + episode_idx, + ) + teacher_chunks = [] + base.eval() + with torch.no_grad(): + for start in range(0, args.kl_replay_size, args.kl_batch_size): + end = min(start + args.kl_batch_size, args.kl_replay_size) + batch = move_batch({k: v[start:end] for k, v in replay.items()}, device) + logits = rollout_logits(base, batch, device) + teacher_chunks.append(logits.detach().to(torch.float32).cpu()) + + replay["teacher_logits"] = torch.cat(teacher_chunks, dim=0) + replay["mask"] = replay["labels"] > 0 + return replay + + +def kl_preservation_loss(args, base, replay, step, device): + if replay is None: + return torch.zeros((), device=device) + + n_replay = replay["inputs"].shape[0] + batch_size = min(args.kl_batch_size, n_replay) + start = (step * batch_size) % n_replay + if start + batch_size <= n_replay: + idx = torch.arange(start, start + batch_size) + else: + idx = torch.cat([torch.arange(start, n_replay), torch.arange(0, start + batch_size - n_replay)]) + + batch = move_batch( + { + "inputs": replay["inputs"][idx], + "labels": replay["labels"][idx], + "puzzle_identifiers": replay["puzzle_identifiers"][idx], + }, + device, + ) + teacher_logits = replay["teacher_logits"][idx].to(device) + mask = replay["mask"][idx].to(device) + was_training = base.training + base.eval() + student_logits = rollout_logits(base, batch, device).to(torch.float32) + if was_training: + base.train() + set_puzzle_embedding_mode(base, args.train_puzzle_emb) + temp = args.kl_temperature + student_logp = F.log_softmax(student_logits / temp, dim=-1) + teacher_p = F.softmax(teacher_logits / temp, dim=-1) + kl_per_token = F.kl_div(student_logp, teacher_p, reduction="none").sum(dim=-1) * (temp ** 2) + if mask.any(): + return kl_per_token[mask].mean() + return kl_per_token.mean() + + +def evaluate(head, base, data_path, n_samples, batch_size, device, seed=42): + rng = np.random.default_rng(seed) + inputs = np.load(data_path / "test" / "all__inputs.npy") + labels = np.load(data_path / "test" / "all__labels.npy") + pid = np.load(data_path / "test" / "all__puzzle_identifiers.npy") + idx_all = rng.choice(len(inputs), size=n_samples, replace=False) + head.eval() + correct = 0 + token_correct = 0 + token_total = 0 + for start in range(0, n_samples, batch_size): + end = min(start + batch_size, n_samples) + idx = idx_all[start:end] + batch = { + "inputs": torch.from_numpy(inputs[idx].astype(np.int32)).to(device), + "labels": torch.from_numpy(labels[idx].astype(np.int32)).to(device), + "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)).to(device), + } + with torch.no_grad(): + with torch.device(device): + carry = base.initial_carry(batch) + for _ in range(base.config.halt_max_steps): + carry, outputs = base(carry=carry, batch=batch) + preds = outputs["logits"].argmax(dim=-1) + mask = batch["labels"] > 0 + exact = ((preds == batch["labels"]) | ~mask).all(dim=-1).float() + correct += exact.sum().item() + token_correct += ((preds == batch["labels"]) & mask).sum().item() + token_total += mask.sum().item() + return correct / n_samples, token_correct / max(token_total, 1) + + +def write_log(path: str, log: dict): + Path(path).write_text(json.dumps(log, indent=2)) + + +def freeze_puzzle_embedding(base): + base.inner.puzzle_emb.eval() + + +def set_puzzle_embedding_mode(base, train_puzzle_emb: bool): + if train_puzzle_emb: + base.inner.puzzle_emb.train() + else: + freeze_puzzle_embedding(base) + + +def make_optimizers(args, base, head, adam_cls, sparse_cls, lr: float, weight_decay: float, train_puzzle_emb: bool): + optimizers = [] + if train_puzzle_emb and getattr(base.inner.config, "puzzle_emb_ndim", 0) > 0: + optimizers.append( + sparse_cls( + base.inner.puzzle_emb.buffers(), + lr=lr if args.puzzle_emb_lr is None else args.puzzle_emb_lr, + weight_decay=args.puzzle_emb_weight_decay, + world_size=1, + ) + ) + optimizers.append(adam_cls(head.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=weight_decay)) + return optimizers + + +def optim_zero_grad(optimizers): + for optim in optimizers: + optim.zero_grad(set_to_none=True) + + +def optim_step(optimizers): + for optim in optimizers: + optim.step() + + +def run_floss_episode(args, head, base, adam_cls, data_path, device, log, episode_idx, train_step): + print( + f"\n=== Floss episode {episode_idx} at train_step={train_step}: " + f"{args.floss_steps} steps, mode={args.floss_mode}, lr={args.floss_lr} ===", + flush=True, + ) + optimizers = make_optimizers( + args, base, head, adam_cls, args.sparse_cls, + lr=args.floss_lr, weight_decay=0.0, train_puzzle_emb=False, + ) + replay = build_kl_replay(args, base, data_path, device, episode_idx) + train_iter = load_train_batches( + data_path, + args.floss_batch_size, + args.floss_steps, + seed=args.seed + 100000 + episode_idx * 1000, + ) + episode = {"episode": episode_idx, "train_step": train_step, "steps": []} + t0 = time.time() + + for step, batch in enumerate(train_iter): + batch = {k: v.to(device) for k, v in batch.items()} + head.train() + set_puzzle_embedding_mode(base, False) + spec = compute_joint_lyap_spec( + args.model, + base, + batch, + k_lyap=args.k_lyap, + lyap_act_steps=args.lyap_act_steps, + device=device, + seed=args.seed + episode_idx * 10000 + step, + lyap_start_act=args.lyap_start_act, + ) + optim_zero_grad(optimizers) + floss_loss, excess = floss_loss_from_spec(spec, args.floss_mode, args.lambda_star) + floss_loss.backward() + kl_loss = kl_preservation_loss(args, base, replay, step, device) + if args.kl_beta > 0: + (args.kl_beta * kl_loss).backward() + torch.nn.utils.clip_grad_norm_([p for p in head.parameters() if p.requires_grad], 1.0) + optim_step(optimizers) + + detached = spec.detach() + total_loss = floss_loss.detach() + args.kl_beta * kl_loss.detach() + rec = { + "step": step, + "loss": float(total_loss.item()), + "floss_loss": float(floss_loss.item()), + "kl_loss": float(kl_loss.item()), + "lyap1_mean": float(detached[:, 0].mean().item()), + "lyap1_max": float(detached[:, 0].max().item()), + "lyap_mean": float(detached.mean().item()), + "volume_mean": float(detached.mean(dim=1).mean().item()), + "volume_max": float(detached.mean(dim=1).max().item()), + "frac_active": float((excess.detach() > 0).float().mean().item()), + } + episode["steps"].append(rec) + if step % args.floss_log_every == 0 or step == args.floss_steps - 1: + print( + f" F[{step:>4}/{args.floss_steps}] dt={time.time() - t0:.1f}s " + f"loss={rec['loss']:.6f} floss={rec['floss_loss']:.6f} " + f"kl={rec['kl_loss']:.6f} lyap1={rec['lyap1_mean']:+.4f} " + f"vol={rec['volume_mean']:+.4f} active={rec['frac_active']:.2f}", + flush=True, + ) + + log["floss_episodes"].append(episode) + if args.eval_after_floss: + acc, tok_acc = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) + print(f" >> FLOSS EVAL train_step={train_step}: exact_acc={acc:.4f}", flush=True) + log["evals"].append( + {"kind": "after_floss", "train_step": train_step, "episode": episode_idx, "acc": acc, "tok_acc": tok_acc} + ) + write_log(args.out, log) + + +def run_task_step(args, head, base, batch, optimizers, device): + batch = {k: v.to(device) for k, v in batch.items()} + head.train() + set_puzzle_embedding_mode(base, args.train_puzzle_emb) + with torch.device(device): + carry = base.initial_carry(batch) + loss_sum = 0.0 + n_loss = 0 + for _ in range(base.config.halt_max_steps): + carry, loss, _metrics, _outputs, all_finish = head(return_keys=[], carry=carry, batch=batch) + loss_sum = loss_sum + loss + n_loss += 1 + if all_finish: + break + sup_loss = loss_sum / max(n_loss, 1) / batch["inputs"].shape[0] + optim_zero_grad(optimizers) + sup_loss.backward() + torch.nn.utils.clip_grad_norm_([p for p in head.parameters() if p.requires_grad], 1.0) + optim_step(optimizers) + return sup_loss + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", choices=["hrm", "trm"], required=True) + parser.add_argument("--ckpt-root", required=True) + parser.add_argument("--ckpt-name", required=True, + help="Checkpoint file name, or __random__ to initialize from config without loading weights.") + parser.add_argument("--train-steps", type=int, default=10000) + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument("--task-batch-size", type=int, default=None, + help="Supervised task microbatch size. Defaults to --batch-size.") + parser.add_argument("--floss-batch-size", type=int, default=None, + help="Flossing microbatch size. Defaults to --batch-size.") + parser.add_argument("--train-lr", type=float, default=1e-5) + parser.add_argument("--floss-lr", type=float, default=1e-4) + parser.add_argument("--floss-steps", type=int, default=500) + parser.add_argument("--interfloss-at", default="0,500") + parser.add_argument("--interfloss-every", type=int, default=0, + help="If >0, also run floss episodes periodically every N task optimizer steps.") + parser.add_argument("--interfloss-start", type=int, default=0, + help="First task optimizer step for periodic interfloss.") + parser.add_argument("--interfloss-stop", type=int, default=-1, + help="Last task optimizer step for periodic interfloss. -1 means train_steps.") + parser.add_argument("--floss-mode", choices=["engelken_l2", "spectrum_cf", "volume_cf", "top1_cf"], default="engelken_l2") + parser.add_argument("--lambda-star", type=float, default=0.0) + parser.add_argument("--k-lyap", type=int, default=8) + parser.add_argument("--lyap-act-steps", type=int, default=4) + parser.add_argument("--lyap-start-act", type=int, default=0, + help="Warm up this many ACT steps before measuring/flossing the Lyapunov window.") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--eval-every", type=int, default=1000) + parser.add_argument("--eval-n", type=int, default=512) + parser.add_argument("--eval-batch-size", type=int, default=32) + parser.add_argument("--floss-log-every", type=int, default=10) + parser.add_argument("--eval-after-floss", action=argparse.BooleanOptionalAction, default=True) + parser.add_argument("--kl-beta", type=float, default=0.0, + help="Episode-start replay-logit KL weight during floss-only steps.") + parser.add_argument("--kl-replay-size", type=int, default=64) + parser.add_argument("--kl-batch-size", type=int, default=8) + parser.add_argument("--kl-temperature", type=float, default=1.0) + parser.add_argument("--init-seed", type=int, default=None, + help="Torch seed used before model construction. Use this for matched from-scratch runs.") + parser.add_argument("--train-puzzle-emb", action=argparse.BooleanOptionalAction, default=False, + help="Train sparse puzzle embeddings. Requires --batch-size to match the model local embedding batch.") + parser.add_argument("--puzzle-emb-lr", type=float, default=None, + help="Sparse puzzle embedding LR. Defaults to current phase LR.") + parser.add_argument("--puzzle-emb-weight-decay", type=float, default=1.0) + parser.add_argument("--out", default="step7_interfloss_log.json") + args = parser.parse_args() + if args.task_batch_size is None: + args.task_batch_size = args.batch_size + if args.floss_batch_size is None: + args.floss_batch_size = args.batch_size + args.batch_size = args.task_batch_size + + device = "cuda" + if args.init_seed is not None: + torch.manual_seed(args.init_seed) + np.random.seed(args.init_seed) + interfloss_steps = build_interfloss_steps(args) + head, base, cfg, adam_cls, sparse_cls = load_model( + args.model, + Path(args.ckpt_root), + args.ckpt_name, + device, + batch_size_override=args.task_batch_size if args.train_puzzle_emb else None, + ) + args.sparse_cls = sparse_cls + data_path = Path(cfg["data_path"]) + + print(f"\n=== Initial eval (loaded {args.ckpt_name}) ===") + acc0, tok0 = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) + print(f" initial: exact_acc={acc0:.4f} token_acc={tok0:.4f}", flush=True) + + log = { + "args": {k: v for k, v in vars(args).items() if k != "sparse_cls"}, + "initial_acc": acc0, + "initial_tok_acc": tok0, + "interfloss_steps": sorted(interfloss_steps), + "task_steps": [], + "floss_episodes": [], + "evals": [{"kind": "initial", "train_step": 0, "acc": acc0, "tok_acc": tok0}], + } + write_log(args.out, log) + + task_optimizers = make_optimizers( + args, base, head, adam_cls, sparse_cls, + lr=args.train_lr, weight_decay=cfg["weight_decay"], train_puzzle_emb=args.train_puzzle_emb, + ) + train_iter = load_train_batches(data_path, args.task_batch_size, args.train_steps, seed=args.seed) + episode_idx = 0 + t0 = time.time() + + for train_step, batch in enumerate(train_iter): + if train_step in interfloss_steps: + run_floss_episode(args, head, base, adam_cls, data_path, device, log, episode_idx, train_step) + episode_idx += 1 + + sup_loss = run_task_step(args, head, base, batch, task_optimizers, device) + rec = {"train_step": train_step + 1, "sup_loss": float(sup_loss.item())} + log["task_steps"].append(rec) + if train_step % 50 == 0 or train_step == args.train_steps - 1: + print( + f" T[{train_step + 1:>5}/{args.train_steps}] dt={time.time() - t0:.1f}s " + f"sup={rec['sup_loss']:.4f}", + flush=True, + ) + + if (train_step + 1) % args.eval_every == 0: + acc, tok_acc = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) + print(f" >> TASK EVAL @ step {train_step + 1}: exact_acc={acc:.4f} delta={acc - acc0:+.4f}", flush=True) + log["evals"].append({"kind": "task", "train_step": train_step + 1, "acc": acc, "tok_acc": tok_acc}) + write_log(args.out, log) + + if args.train_steps in interfloss_steps: + run_floss_episode(args, head, base, adam_cls, data_path, device, log, episode_idx, args.train_steps) + + acc_f, tok_f = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device) + print("\n=== Final eval ===") + print(f" initial={acc0:.4f} final={acc_f:.4f} delta={acc_f - acc0:+.4f}", flush=True) + log["final_acc"] = acc_f + log["final_tok_acc"] = tok_f + log["evals"].append({"kind": "final", "train_step": args.train_steps, "acc": acc_f, "tok_acc": tok_f}) + write_log(args.out, log) + print(f"log -> {args.out}") + + +if __name__ == "__main__": + main() |
