summaryrefslogtreecommitdiff
path: root/research/flossing/step6_prefloss.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
commit66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch)
treec29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/step6_prefloss.py
rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipelineHEADmain
Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/step6_prefloss.py')
-rw-r--r--research/flossing/step6_prefloss.py310
1 files changed, 310 insertions, 0 deletions
diff --git a/research/flossing/step6_prefloss.py b/research/flossing/step6_prefloss.py
new file mode 100644
index 0000000..af84de0
--- /dev/null
+++ b/research/flossing/step6_prefloss.py
@@ -0,0 +1,310 @@
+"""Step 6: Preflossing experiment (Engelken-style separate phases).
+
+Phase 1 — Pure flossing: only optimize L_floss, no task loss.
+ engelken: L = (1/k) Σ λ_i² (push all top-k toward 0, two-sided)
+ cf: L = (1/k) Σ max(0, λ_i)² (only push positive λ toward 0)
+
+Phase 2 — Pure task training: standard HRM ACT loss, no flossing.
+
+Baseline mode (--prefloss-steps 0): skip phase 1, go straight to task training.
+
+Key fix vs step3: lyap_act_steps defaults to halt_max_steps (16 for HRM).
+"""
+from __future__ import annotations
+import sys, os, yaml, json, math, time, argparse
+from pathlib import Path
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+HRM_DIR = Path("/home/yurenh2/rrm/hrm")
+sys.path.insert(0, str(HRM_DIR))
+
+from models.hrm.hrm_act_v1 import HierarchicalReasoningModel_ACTV1
+from models.losses import ACTLossHead
+from adam_atan2 import AdamATan2
+
+
+def load_model(ckpt_root: Path, ckpt_name: str, device: str):
+ cfg = yaml.safe_load((ckpt_root / "all_config.yaml").read_text())
+ arch_cfg = dict(cfg["arch"])
+ train_meta = json.loads((Path(cfg["data_path"]) / "train" / "dataset.json").read_text())
+ arch_cfg.update(batch_size=cfg["global_batch_size"], seq_len=train_meta["seq_len"],
+ vocab_size=train_meta["vocab_size"],
+ num_puzzle_identifiers=train_meta["num_puzzle_identifiers"], causal=False)
+ base = HierarchicalReasoningModel_ACTV1(arch_cfg)
+ head = ACTLossHead(base, loss_type=arch_cfg["loss"]["loss_type"])
+ sd = torch.load(ckpt_root / ckpt_name, map_location="cpu", weights_only=True)
+ stripped = {k.replace("_orig_mod.", ""): v for k, v in sd.items()}
+ missing, unexpected = head.load_state_dict(stripped, strict=False)
+ print(f"[load {ckpt_name}] missing={len(missing)} unexpected={len(unexpected)}")
+ head.to(device)
+ return head, base, cfg, train_meta
+
+
+def jvp_train(f, x, v):
+ return torch.autograd.functional.jvp(f, x, v=v, create_graph=True, strict=False)
+
+
+def compute_joint_lyap_spec(base, batch, k_lyap, lyap_act_steps, device, seed):
+ inner = base.inner
+ cfg = inner.config
+ B = batch["inputs"].shape[0]
+ seq_full = cfg.seq_len + inner.puzzle_emb_len
+ hidden = cfg.hidden_size
+ D = seq_full * hidden
+
+ z_H = inner.H_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype)
+ z_L = inner.L_init.unsqueeze(0).expand(B, seq_full, hidden).clone().to(inner.forward_dtype)
+ seq_info = dict(cos_sin=inner.rotary_emb() if hasattr(inner, "rotary_emb") else None)
+ input_embeddings = inner._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
+
+ g = torch.Generator(device=device).manual_seed(seed)
+ Q0 = torch.randn(B, 2*D, k_lyap, device=device, dtype=torch.float32, generator=g)
+ Q, _ = torch.linalg.qr(Q0)
+ log_R_sum = torch.zeros(B, k_lyap, device=device, dtype=torch.float32)
+ n_steps = 0
+
+ n_act = min(lyap_act_steps, cfg.halt_max_steps)
+ for _act in range(n_act):
+ for _h in range(cfg.H_cycles):
+ for _l in range(cfg.L_cycles):
+ v_H_j = Q[:, :D, :]
+ v_L_j = Q[:, D:, :]
+ v_comb = v_H_j + v_L_j
+ new_v_L_cols = []
+ f_L = lambda z: inner.L_level(z, z_H + input_embeddings, **seq_info)
+ for i in range(k_lyap):
+ v_i = v_comb[:, :, i].reshape(B, seq_full, hidden).to(inner.forward_dtype)
+ z_L_new, Dv = jvp_train(f_L, z_L, v_i)
+ new_v_L_cols.append(Dv.reshape(B, D).to(torch.float32))
+ new_v_L = torch.stack(new_v_L_cols, dim=-1)
+ Q = torch.cat([v_H_j, new_v_L], dim=1)
+ z_L = z_L_new
+ Q, R = torch.linalg.qr(Q)
+ log_R_sum = log_R_sum + R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log()
+ n_steps += 1
+ v_H_j = Q[:, :D, :]
+ v_L_j = Q[:, D:, :]
+ v_comb = v_H_j + v_L_j
+ new_v_H_cols = []
+ f_H = lambda z: inner.H_level(z, z_L, **seq_info)
+ for i in range(k_lyap):
+ v_i = v_comb[:, :, i].reshape(B, seq_full, hidden).to(inner.forward_dtype)
+ z_H_new, Dv = jvp_train(f_H, z_H, v_i)
+ new_v_H_cols.append(Dv.reshape(B, D).to(torch.float32))
+ new_v_H = torch.stack(new_v_H_cols, dim=-1)
+ Q = torch.cat([new_v_H, v_L_j], dim=1)
+ z_H = z_H_new
+ Q, R = torch.linalg.qr(Q)
+ log_R_sum = log_R_sum + R.diagonal(dim1=-2, dim2=-1).abs().clamp_min(1e-30).log()
+ n_steps += 1
+
+ return log_R_sum / max(n_steps, 1)
+
+
+def load_train_batches(data_path: Path, batch_size: int, n_iters: int, seed: int = 0):
+ rng = np.random.default_rng(seed)
+ inputs = np.load(data_path / "train" / "all__inputs.npy")
+ labels = np.load(data_path / "train" / "all__labels.npy")
+ pid = np.load(data_path / "train" / "all__puzzle_identifiers.npy")
+ N = len(inputs)
+ for _ in range(n_iters):
+ idx = rng.choice(N, size=batch_size, replace=False)
+ yield {
+ "inputs": torch.from_numpy(inputs[idx].astype(np.int32)),
+ "labels": torch.from_numpy(labels[idx].astype(np.int32)),
+ "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)),
+ }
+
+
+def evaluate(head, base, data_path, n_samples, batch_size, device, seed=42):
+ rng = np.random.default_rng(seed)
+ inputs = np.load(data_path / "test" / "all__inputs.npy")
+ labels = np.load(data_path / "test" / "all__labels.npy")
+ pid = np.load(data_path / "test" / "all__puzzle_identifiers.npy")
+ idx_all = rng.choice(len(inputs), size=n_samples, replace=False)
+ head.eval()
+ correct = 0; token_correct = 0; token_total = 0
+ for s in range(0, n_samples, batch_size):
+ e = min(s + batch_size, n_samples)
+ idx = idx_all[s:e]
+ batch = {
+ "inputs": torch.from_numpy(inputs[idx].astype(np.int32)).to(device),
+ "labels": torch.from_numpy(labels[idx].astype(np.int32)).to(device),
+ "puzzle_identifiers": torch.from_numpy(pid[idx].astype(np.int32)).to(device),
+ }
+ with torch.no_grad():
+ with torch.device(device):
+ carry = base.initial_carry(batch)
+ for _ in range(base.config.halt_max_steps):
+ carry, outputs = base(carry=carry, batch=batch)
+ preds = outputs["logits"].argmax(dim=-1)
+ mask = batch["labels"] > 0
+ exact = ((preds == batch["labels"]) | ~mask).all(dim=-1).float()
+ correct += exact.sum().item()
+ token_correct += ((preds == batch["labels"]) & mask).sum().item()
+ token_total += mask.sum().item()
+ return correct / n_samples, token_correct / max(token_total, 1)
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--ckpt-root", required=True)
+ ap.add_argument("--ckpt-name", default="step_13020")
+ ap.add_argument("--prefloss-steps", type=int, default=500,
+ help="Phase 1: pure flossing steps. 0 = skip (baseline).")
+ ap.add_argument("--train-steps", type=int, default=3000,
+ help="Phase 2: pure task training steps.")
+ ap.add_argument("--floss-mode", choices=["engelken", "cf", "volume_cf"], default="engelken",
+ help="engelken: Σλ_i² (two-sided). cf: Σmax(0,λ_i)² (one-sided hinge). "
+ "volume_cf: max(0, mean_i λ_i)² over the measured top-k spectrum.")
+ ap.add_argument("--batch-size", type=int, default=8)
+ ap.add_argument("--floss-lr", type=float, default=1e-4,
+ help="LR for flossing phase (Engelken uses higher LR)")
+ ap.add_argument("--train-lr", type=float, default=1e-5,
+ help="LR for task training phase")
+ ap.add_argument("--k-lyap", type=int, default=2)
+ ap.add_argument("--lyap-act-steps", type=int, default=16,
+ help="ACT steps for Lyapunov computation (default=halt_max_steps)")
+ ap.add_argument("--seed", type=int, default=42)
+ ap.add_argument("--eval-every", type=int, default=100)
+ ap.add_argument("--eval-n", type=int, default=512)
+ ap.add_argument("--eval-batch-size", type=int, default=32)
+ ap.add_argument("--out", default="step6_log.json")
+ args = ap.parse_args()
+
+ device = "cuda"
+ head, base, cfg, train_meta = load_model(Path(args.ckpt_root), args.ckpt_name, device)
+ data_path = Path(cfg["data_path"])
+
+ print(f"\n=== Initial eval (loaded {args.ckpt_name}) ===")
+ acc0, tacc0 = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device)
+ print(f" initial: exact_acc = {acc0:.4f} token_acc = {tacc0:.4f}")
+
+ log = {"args": vars(args), "initial_acc": acc0, "initial_tok_acc": tacc0,
+ "phase1_steps": [], "phase1_evals": [], "phase2_steps": [], "phase2_evals": []}
+
+ global_step = 0
+
+ # ========== PHASE 1: Pure flossing ==========
+ if args.prefloss_steps > 0:
+ print(f"\n=== Phase 1: Pure {args.floss_mode} flossing ({args.prefloss_steps} steps, lr={args.floss_lr}) ===")
+ floss_optim = AdamATan2(head.parameters(), lr=args.floss_lr, betas=(0.9, 0.95), weight_decay=0.0)
+ floss_iter = load_train_batches(data_path, args.batch_size, args.prefloss_steps, seed=args.seed)
+ t0 = time.time()
+
+ for step, batch in enumerate(floss_iter):
+ batch = {k: v.to(device) for k, v in batch.items()}
+ head.train()
+ base.inner.puzzle_emb.eval()
+ for p in base.inner.puzzle_emb.parameters():
+ p.requires_grad_(False)
+
+ lyap_spec = compute_joint_lyap_spec(
+ base, batch, k_lyap=args.k_lyap,
+ lyap_act_steps=args.lyap_act_steps, device=device,
+ seed=args.seed + step,
+ )
+
+ if args.floss_mode == "engelken":
+ floss_loss = (lyap_spec ** 2).mean()
+ elif args.floss_mode == "volume_cf":
+ floss_loss = (lyap_spec.mean(dim=1).clamp_min(0.0) ** 2).mean()
+ else:
+ floss_loss = (lyap_spec.clamp_min(0.0) ** 2).mean()
+
+ floss_optim.zero_grad(set_to_none=True)
+ floss_loss.backward()
+ torch.nn.utils.clip_grad_norm_([p for p in head.parameters() if p.requires_grad], 1.0)
+ floss_optim.step()
+
+ lyap1 = lyap_spec[:, 0].detach()
+ rec = {
+ "step": step, "floss_loss": float(floss_loss.item()),
+ "lyap1_mean": float(lyap1.mean().item()),
+ "lyap1_max": float(lyap1.max().item()),
+ "lyap_all_mean": float(lyap_spec.detach().mean().item()),
+ "lyap_volume_mean": float(lyap_spec.detach().mean(dim=1).mean().item()),
+ "lyap_volume_max": float(lyap_spec.detach().mean(dim=1).max().item()),
+ }
+ log["phase1_steps"].append(rec)
+ if step % 10 == 0 or step == args.prefloss_steps - 1:
+ print(f" P1[{step:>4}/{args.prefloss_steps}] dt={time.time()-t0:.1f}s "
+ f"floss={rec['floss_loss']:.6f} "
+ f"λ1={rec['lyap1_mean']:+.4f} max={rec['lyap1_max']:+.4f} "
+ f"λ_all={rec['lyap_all_mean']:+.4f}", flush=True)
+
+ if (step + 1) % args.eval_every == 0:
+ acc, tacc = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device)
+ print(f" >> P1 EVAL @ step {step+1}: exact_acc={acc:.4f} (Δ={acc-acc0:+.4f})", flush=True)
+ log["phase1_evals"].append({"step": step + 1, "acc": acc, "tok_acc": tacc})
+
+ acc_p1, tacc_p1 = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device)
+ print(f" Phase 1 final: exact_acc={acc_p1:.4f} (Δ from init: {acc_p1-acc0:+.4f})")
+ log["phase1_final_acc"] = acc_p1
+ log["phase1_evals"].append({"step": args.prefloss_steps, "acc": acc_p1, "tok_acc": tacc_p1})
+ global_step = args.prefloss_steps
+ else:
+ print("\n=== Phase 1 skipped (baseline mode) ===")
+ log["phase1_final_acc"] = acc0
+
+ # ========== PHASE 2: Pure task training ==========
+ print(f"\n=== Phase 2: Pure task training ({args.train_steps} steps, lr={args.train_lr}) ===")
+ train_optim = AdamATan2(head.parameters(), lr=args.train_lr, betas=(0.9, 0.95),
+ weight_decay=cfg["weight_decay"])
+ train_iter = load_train_batches(data_path, args.batch_size, args.train_steps,
+ seed=args.seed + 10000)
+ t0 = time.time()
+ acc_ref = log.get("phase1_final_acc", acc0)
+
+ for step, batch in enumerate(train_iter):
+ batch = {k: v.to(device) for k, v in batch.items()}
+ head.train()
+ base.inner.puzzle_emb.eval()
+ for p in base.inner.puzzle_emb.parameters():
+ p.requires_grad_(False)
+
+ with torch.device(device):
+ carry = base.initial_carry(batch)
+ sup_loss_sum = 0.0
+ n_loss = 0
+ for _ in range(base.config.halt_max_steps):
+ carry, l, metrics, _, all_finish = head(return_keys=[], carry=carry, batch=batch)
+ sup_loss_sum = sup_loss_sum + l
+ n_loss += 1
+ if all_finish:
+ break
+ sup_loss = sup_loss_sum / max(n_loss, 1) / args.batch_size
+
+ train_optim.zero_grad(set_to_none=True)
+ sup_loss.backward()
+ torch.nn.utils.clip_grad_norm_([p for p in head.parameters() if p.requires_grad], 1.0)
+ train_optim.step()
+
+ rec = {"step": step, "sup_loss": float(sup_loss.item())}
+ log["phase2_steps"].append(rec)
+ if step % 50 == 0 or step == args.train_steps - 1:
+ print(f" P2[{step:>4}/{args.train_steps}] dt={time.time()-t0:.1f}s "
+ f"sup={rec['sup_loss']:.4f}", flush=True)
+
+ if (step + 1) % args.eval_every == 0:
+ acc, tacc = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device)
+ print(f" >> P2 EVAL @ step {step+1}: exact_acc={acc:.4f} "
+ f"(Δ from init: {acc-acc0:+.4f}, Δ from P1: {acc-acc_ref:+.4f})", flush=True)
+ log["phase2_evals"].append({"step": global_step + step + 1, "acc": acc, "tok_acc": tacc})
+
+ acc_f, tacc_f = evaluate(head, base, data_path, args.eval_n, args.eval_batch_size, device)
+ print(f"\n=== Final eval ===")
+ print(f" initial: {acc0:.4f} phase1_end: {log.get('phase1_final_acc', acc0):.4f} "
+ f"final: {acc_f:.4f} (total Δ: {acc_f-acc0:+.4f})")
+ log["final_acc"] = acc_f
+ log["final_tok_acc"] = tacc_f
+ log["phase2_evals"].append({"step": global_step + args.train_steps, "acc": acc_f, "tok_acc": tacc_f})
+
+ Path(args.out).write_text(json.dumps(log, indent=2))
+ print(f"log → {args.out}")
+
+
+if __name__ == "__main__":
+ main()