diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/train_local_ce.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/train_local_ce.py')
| -rw-r--r-- | ep_run/train_local_ce.py | 580 |
1 files changed, 580 insertions, 0 deletions
# Reconstructed from the patch view of ep_run/train_local_ce.py (new file):
# diff "+" markers removed and indentation restored.
"""Local CE exit training — each block gets a vocab-space CE loss via shared unembedding.

Each block l computes:
    z_l = W_U @ T_l(h_l)     (local logits via shared unembedding + optional translator)
    L_l = λ_gt * CE(z_l, y) + λ_kd * τ² * KL(sg(p_L^τ) || p_l^τ)

Forward weights updated per-block via local CE gradient (intra-block only).
No inter-block chain rule. Fused attention backward within each block.

This replaces the hidden-space MSE target-matching that failed at scale.
"""
import argparse
import json
import math
import pickle
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from model_local import LocalGPTConfig, LocalBlock, LocalLinear, _make_ln
from factorized_exit import FactorizedExitHead, ExactParallelExitHead
from local_layers import initialize_dfa_block_targets, apply_dfa_block_update


def get_batch(split, data_dir, block_size, batch_size, device, n_pred=1):
    """Sample a random batch of token windows from ``train.bin`` / ``val.bin``.

    Args:
        split: ``"train"`` reads ``train.bin``; any other value reads ``val.bin``.
        data_dir: directory holding the uint16 token files (``Path`` or ``str``).
        block_size: number of tokens per input window.
        batch_size: number of windows to sample.
        device: device the returned tensors are moved to.
        n_pred: number of future tokens to return as targets per position.

    Returns:
        ``(x, y)`` as int64 tensors. ``x`` is ``(batch_size, block_size)``.
        For ``n_pred == 1``, ``y`` is ``(batch_size, block_size)`` (next token).
        For ``n_pred > 1``, ``y`` is ``(batch_size, block_size, n_pred)`` with
        ``y[..., k-1]`` holding the next-k target.

    Raises:
        ValueError: if ``n_pred < 1`` or the file is too short for one window.
    """
    if n_pred < 1:
        raise ValueError(f"n_pred must be >= 1, got {n_pred}")

    fn = "train.bin" if split == "train" else "val.bin"
    # The memmap is re-opened on every call; only the sampled windows are read.
    data = np.memmap(Path(data_dir) / fn, dtype=np.uint16, mode="r")

    n_starts = len(data) - block_size - n_pred
    if n_starts <= 0:
        raise ValueError(
            f"{fn} has {len(data)} tokens; need more than "
            f"{block_size + n_pred} for block_size={block_size}, n_pred={n_pred}"
        )
    starts = torch.randint(n_starts, (batch_size,)).tolist()

    def window(begin):
        # uint16 -> int64 copy so the result can be used as embedding indices.
        return torch.from_numpy(data[begin : begin + block_size].astype(np.int64))

    x = torch.stack([window(i) for i in starts])
    if n_pred == 1:
        y = torch.stack([window(i + 1) for i in starts])
        return x.to(device, non_blocking=True), y.to(device, non_blocking=True)

    # n_pred > 1: targets shape (B, T, n_pred). Y[..., k-1] = next-k target.
    y_multi = torch.stack([
        torch.stack([window(i + k) for k in range(1, n_pred + 1)], dim=-1)
        for i in starts
    ])
    return x.to(device, non_blocking=True), y_multi.to(device, non_blocking=True)


class LowRankTranslator(nn.Module):
    """T_l(h) = h + A @ B @ h + b. Low-rank affine residual translator."""

    def __init__(self, d_model, rank=32):
        super().__init__()
        self.A = nn.Parameter(torch.zeros(d_model, rank))
        self.B = nn.Parameter(torch.zeros(rank, d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))
        # Both factors start as small Gaussians, so the low-rank term is
        # small but non-zero at init; the bias starts at zero.
        nn.init.normal_(self.A, std=0.01)
        nn.init.normal_(self.B, std=0.01)

    def forward(self, h):
        # Row-vector form of h + A @ B @ h + b applied over the last dim.
        return h + h @ self.B.T @ self.A.T + self.bias


class LocalCETransformer(nn.Module):
    """Transformer with per-block local CE exits via shared unembedding."""

    def __init__(self, config: LocalGPTConfig, translator_rank: int = 0, n_pred_tokens: int = 1,
                 shared_blocks: bool = False):
        """Build embeddings, blocks, final norm/head and the optional extras.

        Args:
            config: model hyper-parameters (vocab_size, n_embd, block_size,
                dropout and n_layer are read here).
            translator_rank: if > 0, one ``LowRankTranslator`` of this rank is
                created per layer; otherwise local logits use h directly.
            n_pred_tokens: if > 1, ``n_pred_tokens - 1`` auxiliary unembedding
                heads are created for next-2..next-N prediction.
            shared_blocks: reuse a single ``LocalBlock`` for every layer.
        """
        super().__init__()
        self.config = config
        self.n_pred_tokens = n_pred_tokens
        self.shared_blocks = shared_blocks
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.block_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        if shared_blocks:
            # Universal Transformer: one block applied n_layer times.
            # All entries point to the SAME module — gradient accumulates from all "depths".
            shared = LocalBlock(config)
            self.blocks = nn.ModuleList([shared for _ in range(config.n_layer)])
        else:
            self.blocks = nn.ModuleList([LocalBlock(config) for _ in range(config.n_layer)])
        self.ln_f = _make_ln(config)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Auxiliary unembedding heads for next-2..next-N prediction (multi-token training).
        # Used only as gradient-source heads at training time; inference still uses self.head.
        if n_pred_tokens > 1:
            self.aux_heads = nn.ModuleList([
                nn.Linear(config.n_embd, config.vocab_size, bias=False)
                for _ in range(n_pred_tokens - 1)
            ])
        else:
            self.aux_heads = None

        # Per-block translators (logit lens = rank 0 = identity)
        if translator_rank > 0:
            self.translators = nn.ModuleList([
                LowRankTranslator(config.n_embd, translator_rank)
                for _ in range(config.n_layer)
            ])
        else:
            self.translators = None

        self.apply(self._init_weights)
        # Output projections get a smaller std, scaled by 1/sqrt(2 * n_layer).
        for pn, p in self.named_parameters():
            if pn.endswith(("o_proj.weight", "mlp.proj.weight")):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, m):
        """N(0, 0.02) init for Linear/LocalLinear/Embedding weights; zero biases."""
        if isinstance(m, (nn.Linear, LocalLinear)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward_activations(self, idx):
        """Return ``[h_0, h_1, ..., h_L]``: embedding output then each block's output.

        ``idx`` is a ``(B, T)`` tensor of token ids.
        """
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        h = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        activations = [h]
        for block in self.blocks:
            h = block(h)
            activations.append(h)
        return activations

    def local_logits(self, h, layer_idx):
        """h → local logits via optional translator + shared unembedding."""
        if self.translators is not None:
            h = self.translators[layer_idx](h)
        return F.linear(h, self.head.weight)  # shared W_U, no separate head

    def final_logits(self, h):
        """Final-layer logits: ``head(ln_f(h))``."""
        return self.head(self.ln_f(h))


def _build_arg_parser():
    """Create the command-line parser for the local-CE training script."""
    p = argparse.ArgumentParser()
    p.add_argument("--run_name", type=str, required=True)
    p.add_argument("--seed", type=int, default=1337)
    p.add_argument("--data_dir", type=str, default="data/shakespeare_char")
    p.add_argument("--out_dir", type=str, default="runs_local")
    p.add_argument("--block_size", type=int, default=256)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--n_layer", type=int, default=6)
    p.add_argument("--n_head", type=int, default=6)
    p.add_argument("--n_embd", type=int, default=384)
    p.add_argument("--dropout", type=float, default=0.2)
    p.add_argument("--max_iters", type=int, default=5000)
    p.add_argument("--warmup_iters", type=int, default=100)
    p.add_argument("--max_lr", type=float, default=1e-3)
    p.add_argument("--min_lr", type=float, default=1e-4)
    p.add_argument("--eval_interval", type=int, default=250)
    p.add_argument("--eval_iters", type=int, default=100)
    p.add_argument("--log_interval", type=int, default=50)
    p.add_argument("--attn_mode", type=str, default="softmax")
    p.add_argument("--translator_rank", type=int, default=0,
                   help="0=identity (logit lens), >0=low-rank affine")
    p.add_argument("--kd_weight", type=float, default=1.0,
                   help="weight for KL distillation from final layer")
    p.add_argument("--kd_temp", type=float, default=2.0, help="temperature for KD")
    p.add_argument("--gt_weight", type=float, default=1.0, help="weight for ground-truth CE")
    p.add_argument("--nbr_weight", type=float, default=0.0,
                   help="weight for neighbor KL (sg(p_{l+1}) || p_l)")
    p.add_argument("--layer_weighting", type=str, default="uniform", choices=["uniform", "linear"],
                   help="per-layer loss weight: uniform=all 1.0, linear=l/L")
    p.add_argument("--bp_free_exit", type=str, default="none",
                   choices=["none", "dense", "hybrid", "parallel_only", "parallel_gold", "parallel_topmass"],
                   help="BP-free exit: none=W_U^T, dense/hybrid=compressor, parallel_*=exact parallel term")
    p.add_argument("--exit_rank", type=int, default=128, help="rank for BP-free exit compressor")
    p.add_argument("--exit_rank_exact", type=int, default=32, help="exact rank for hybrid compressor")
    p.add_argument("--exit_topk", type=int, default=8, help="top-k for hybrid compressor")
    p.add_argument("--exit_residual_rank", type=int, default=32,
                   help="residual_rank for ExactParallelExitHead (parallel_gold/topmass): code dim for h-perp residual")
    p.add_argument("--intra_block_method", type=str, default="bp",
                   choices=["bp", "fa", "sign_sym", "dfa_block"],
                   help="intra-block: bp=W^T, fa=seq random B, sign_sym=sign(W)·rescale, dfa_block=direct from block-output-error")
    p.add_argument("--mlp_topk", type=int, default=0,
                   help="if >0, apply hard k-WTA to MLP hidden activation (4*n_embd dim)")
    p.add_argument("--resid_topk", type=int, default=0,
                   help="if >0, apply hard k-WTA to residual stream output of each block (n_embd dim)")
    p.add_argument("--vq_codes", type=int, default=0,
                   help="if >0, apply directional VQ to residual stream at each block (K codebook entries, frozen)")
    p.add_argument("--subspace_rank", type=int, default=0,
                   help="if >0, project residual stream to fixed r-dim orthonormal subspace at each block")
    p.add_argument("--subspace_per_layer", action="store_true",
                   help="use DIFFERENT random Q per layer (ablation: tests if shared Q is necessary)")
    p.add_argument("--fa_init_sign", action="store_true",
                   help="init FA's fixed B as sign(W_init)*rescale instead of random (frozen sign_sym)")
    p.add_argument("--shared_blocks", action="store_true",
                   help="Universal Transformer: all blocks share the same parameters (single block applied n_layer times)")
    p.add_argument("--fa_init", type=str, default="gaussian",
                   choices=["gaussian", "orthogonal", "ortho_he", "sparse"],
                   help="FA's fixed B init mode (gaussian=Lillicrap, orthogonal=JL-isometric, ortho_he=He-init backward, sparse=structured)")
    p.add_argument("--fa_sparse_k", type=int, default=0,
                   help="for fa_init=sparse: non-zero entries per row (0 = auto = in_features/16)")
    p.add_argument("--gated_blocks", action="store_true",
                   help="Path IV: learned per-block residual gates (α_attn, α_mlp). Lets useless layers self-deactivate.")
    p.add_argument("--progression_targets", action="store_true",
                   help="Path I: each block l predicts next-(l+1) token (progressive prediction horizons per layer)")
    p.add_argument("--weight_normalize", action="store_true",
                   help="Meta-PCN style WN: after each optimizer step, normalize LocalLinear's W by (sqrt(m)+sqrt(n))*std(W) to keep ||W||_2 ~= 1")
    p.add_argument("--pc_inference", type=int, default=0,
                   help="Predictive coding inference steps T (T=0 disables PC mode, uses standard local CE)")
    p.add_argument("--pc_inference_lr", type=float, default=0.1,
                   help="Inference step size η for PC z updates")
    p.add_argument("--pc_top_weight", type=float, default=1.0,
                   help="Weight of top-down CE term in PC energy F")
    p.add_argument("--fa_grape", action="store_true",
                   help="GrAPE: per-step JVP-based cosine alignment of FA's B toward true Jacobian (Caillon et al. 2026)")
    p.add_argument("--fa_grape_lr", type=float, default=0.01,
                   help="Learning rate for GrAPE B alignment update")
    p.add_argument("--fa_grape_n_probe", type=int, default=32,
                   help="Number of probe samples for JVP rank-1 Jacobian estimate")
    p.add_argument("--save_ckpt", action="store_true",
                   help="save final model state to run_dir/ckpt.pt for downstream probing")
    p.add_argument("--n_pred_tokens", type=int, default=1,
                   help="multi-token prediction: predict next-1..next-N (N=1 disables, default)")
    p.add_argument("--aux_weight", type=float, default=0.3,
                   help="weight for aux next-k losses (k=2..N). Primary next-1 always weight 1.0.")
    return p


def _required_n_pred(n_pred_tokens, progression_targets, n_layer):
    """Number of target columns a training batch must carry.

    Multi-token prediction needs ``n_pred_tokens`` columns; progression targets
    need one column per layer because block l reads column l. When both are
    enabled the larger requirement wins, so neither lookup can run out of range.
    """
    n_pred = n_pred_tokens if n_pred_tokens > 1 else 1
    if progression_targets:
        n_pred = max(n_pred, n_layer)
    return n_pred


def _weight_normalize(model):
    """Divide each LocalLinear weight by ``(sqrt(m) + sqrt(n)) * std(W)`` in place.

    Meta-PCN style weight normalization: uses the random-matrix-theory estimate
    ||W||_2 ~= (sqrt(m) + sqrt(n)) * std(W) to keep the spectral norm near 1.
    Only ``module.weight`` is touched; near-zero scales are skipped.
    """
    with torch.no_grad():
        for module in model.modules():
            if isinstance(module, LocalLinear):
                m, n = module.weight.shape
                sigma_w = module.weight.std()
                scale = (m ** 0.5 + n ** 0.5) * sigma_w
                if scale > 1e-8:
                    module.weight.div_(scale)


def main():
    """Parse CLI arguments, build the model and run local-CE (or PC-mode) training.

    Writes a JSON-lines log to ``<out_dir>/<run_name>/log.jsonl`` (truncated at
    start) and, with ``--save_ckpt``, a final checkpoint to ``ckpt.pt``.
    """
    args = _build_arg_parser().parse_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    data_dir = Path(args.data_dir)
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # point --data_dir at datasets you prepared or trust.
    with open(data_dir / "meta.pkl", "rb") as f:
        meta = pickle.load(f)
    vocab_size = meta["vocab_size"]

    run_dir = Path(args.out_dir) / args.run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    log_path = run_dir / "log.jsonl"
    log_path.write_text("")  # start each run with an empty log

    cfg = LocalGPTConfig(
        block_size=args.block_size, vocab_size=vocab_size,
        n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd,
        dropout=args.dropout, attn_mode=args.attn_mode,
        method=args.intra_block_method,
        fuse_attn_local=True, ste_gelu=True, ln_mode="center_scale",
        mlp_topk=args.mlp_topk, resid_topk=args.resid_topk,
        vq_codes=args.vq_codes, subspace_rank=args.subspace_rank,
        fa_init_mode=args.fa_init, fa_sparse_k=args.fa_sparse_k,
        gated_blocks=args.gated_blocks,
        fa_grape=args.fa_grape, fa_grape_n_probe=args.fa_grape_n_probe,
    )
    model = LocalCETransformer(cfg, translator_rank=args.translator_rank,
                               n_pred_tokens=args.n_pred_tokens,
                               shared_blocks=args.shared_blocks).to(device)

    # Frozen sign_sym: replace FA's random B with sign(W_init)*rescale, then freeze.
    # B is still a fixed buffer (BP-free by definition B), just structured init.
    if args.fa_init_sign and args.intra_block_method == "fa":
        with torch.no_grad():
            for module in model.modules():
                if isinstance(module, LocalLinear) and module.method == "fa":
                    scale = module.weight.norm() / (module.weight.numel() ** 0.5 + 1e-8)
                    module.B.copy_(torch.sign(module.weight) * scale)

    # Per-layer different Q ablation: replace each block's shared-seed subspace
    # with independently-seeded subspace (tests if shared Q is the mechanism)
    if args.subspace_per_layer and args.subspace_rank > 0:
        from model_local import FrozenSubspace
        for i, block in enumerate(model.blocks):
            block.subspace = FrozenSubspace(args.n_embd, args.subspace_rank, seed=1000 + i).to(device)

    # Initialize DFA-block targets if needed
    if args.intra_block_method == "dfa_block":
        initialize_dfa_block_targets(model, args.n_embd)

    # BP-free exit heads (one per block)
    # NOTE(review): these live outside `model`, so they are not covered by
    # model.parameters() (optimizer, grad clipping, n_params) nor by the saved
    # state_dict. If the exit-head classes hold trainable parameters they are
    # never updated — confirm this is intended.
    exit_heads = None
    if args.bp_free_exit in ("dense", "hybrid"):
        exit_heads = nn.ModuleList([
            FactorizedExitHead(
                args.n_embd, vocab_size, mode=args.bp_free_exit,
                rank=args.exit_rank, rank_exact=args.exit_rank_exact, topk=args.exit_topk,
            ) for _ in range(cfg.n_layer)
        ]).to(device)
    elif args.bp_free_exit.startswith("parallel"):
        exit_heads = nn.ModuleList([
            ExactParallelExitHead(
                args.n_embd, vocab_size, mode=args.bp_free_exit,
                residual_rank=args.exit_residual_rank,
            ) for _ in range(cfg.n_layer)
        ]).to(device)

    n_params = sum(p.numel() for p in model.parameters())
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.max_lr, weight_decay=0.1)

    t0 = time.time()

    def log(rec):
        """Append one record (stamped with elapsed seconds) to the JSONL log."""
        rec["t"] = time.time() - t0
        with open(log_path, "a") as f:
            f.write(json.dumps(rec) + "\n")

    log({"event": "start", "method": "local_ce", "params": n_params,
         "translator_rank": args.translator_rank, "config": vars(args)})
    print(f"[{args.run_name}] local_ce, params={n_params/1e6:.2f}M, translator_rank={args.translator_rank}")

    def lr_schedule(it):
        """Linear warmup to max_lr, then cosine decay to min_lr at max_iters."""
        if it < args.warmup_iters:
            return args.max_lr * (it + 1) / (args.warmup_iters + 1)
        decay = 0.5 * (1 + math.cos(math.pi * (it - args.warmup_iters) /
                                    max(1, args.max_iters - args.warmup_iters)))
        return args.min_lr + decay * (args.max_lr - args.min_lr)

    @torch.no_grad()
    def eval_loss():
        """Mean final-layer next-token CE over ``eval_iters`` validation batches."""
        model.eval()
        losses = torch.zeros(args.eval_iters)
        for k in range(args.eval_iters):
            X, Y = get_batch("val", data_dir, args.block_size, args.batch_size, device)
            acts = model.forward_activations(X)
            logits = model.final_logits(acts[-1])
            loss = F.cross_entropy(logits.view(-1, vocab_size), Y.view(-1))
            losses[k] = loss.item()
        model.train()
        return losses.mean().item()

    # Determine n_pred for batch fetch (constant across iterations)
    #   - n_pred_tokens > 1: multi-token MTP aux losses (each block predicts N targets via N heads)
    #   - progression_targets: each block l predicts next-(l+1) (so need n_pred >= n_layer)
    # Both may be active at once, so take the larger requirement.
    n_pred = _required_n_pred(args.n_pred_tokens, args.progression_targets, cfg.n_layer)

    model.train()
    for it in range(args.max_iters + 1):
        lr = lr_schedule(it)
        for g in optimizer.param_groups:
            g["lr"] = lr

        if it % args.eval_interval == 0 or it == args.max_iters:
            val = eval_loss()
            log({"event": "eval", "iter": it, "val_loss": val, "lr": lr})
            print(f"[{args.run_name}] iter {it:5d} val {val:.4f} lr {lr:.4g}")

        if it == args.max_iters:
            break

        if n_pred > 1:
            X, Y_multi = get_batch("train", data_dir, args.block_size, args.batch_size, device,
                                   n_pred=n_pred)
            Y = Y_multi[..., 0]  # (B, T) — next-1 target for default
        else:
            X, Y = get_batch("train", data_dir, args.block_size, args.batch_size, device)
            Y_multi = None

        # ============================================================
        # PC mode: predictive coding inference + Hebbian-style updates
        # (when --pc_inference T > 0, replaces standard local CE)
        # ============================================================
        if args.pc_inference > 0:
            optimizer.zero_grad()
            Y_flat = Y.view(-1)

            # 1. Forward init (no autograd graph during init)
            with torch.no_grad():
                init_acts = model.forward_activations(X)
            # z[0] = embedding, clamped (no grad). z[1..L] = block outputs, evolve.
            # NOTE(review): z[0] is detached everywhere in this branch, so
            # tok_emb / pos_emb receive no gradient in PC mode — confirm intended.
            z = [init_acts[0].detach()]
            for layer_idx in range(1, len(init_acts)):
                z.append(init_acts[layer_idx].detach().clone().requires_grad_(True))

            # 2. PC inference: T iterations of z updates via ∂F/∂z
            #    F = Σ_{l<L} (1/2) ||z_l - block_{l-1}(z_{l-1})||² / d + λ·CE(W_U @ z_L, y)
            #    Skip PE_L (Meta-PCN trick: CE replaces last-layer squared error)
            #    Use mean over hidden dim (per-token PE²) for scale-invariance.
            for _ in range(args.pc_inference):
                F_energy = 0.0
                for layer_idx in range(1, len(z) - 1):  # l = 1..L-1, skip PE_L
                    z_hat = model.blocks[layer_idx - 1](z[layer_idx - 1])
                    pe = z[layer_idx] - z_hat
                    F_energy = F_energy + 0.5 * (pe ** 2).mean()  # mean over (B,T,d) → scale-invariant
                # Top-down: CE at z_L (replaces PE_L per Meta-PCN convention)
                logits_top = F.linear(z[-1], model.head.weight)
                CE_top = F.cross_entropy(logits_top.view(-1, vocab_size), Y_flat)
                F_total = F_energy + args.pc_top_weight * CE_top
                # Compute ∂F/∂z[1..L] (FA-flavored due to LocalLinear FA backward inside blocks)
                grads = torch.autograd.grad(F_total, z[1:], create_graph=False, retain_graph=False)
                # SGD update on z's
                with torch.no_grad():
                    new_z = [z[0]]
                    for i, g in enumerate(grads):
                        new_z.append((z[i + 1] - args.pc_inference_lr * g).detach().requires_grad_(True))
                z = new_z

            # 3. Weight update via per-block PE loss using converged z's
            #    For block l-1: minimize ||sg(z_l) - block_{l-1}(sg(z_{l-1}))||²
            #    backward gives FA-flavored W gradients (Hebbian-equivalent at equilibrium)
            total_loss = 0.0
            for layer_idx in range(1, len(z)):
                z_hat = model.blocks[layer_idx - 1](z[layer_idx - 1].detach())
                target = z[layer_idx].detach()
                pe_loss = 0.5 * ((target - z_hat) ** 2).mean()
                pe_loss.backward()
                total_loss += pe_loss.item()

            # Final head + ln_f via CE on converged z[-1]
            final_z = model.final_logits(z[-1].detach())
            head_loss = F.cross_entropy(final_z.view(-1, vocab_size), Y_flat)
            head_loss.backward()
            total_loss += head_loss.item()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Optional WN
            if args.weight_normalize:
                _weight_normalize(model)

            if it % args.log_interval == 0:
                log({"event": "step", "iter": it,
                     "total_loss": total_loss / (cfg.n_layer + 1),
                     "head_loss": head_loss.item(), "lr": lr})
            continue
        # ============================================================
        # End PC mode; standard local CE follows
        # ============================================================

        # Forward: compute all activations
        activations = model.forward_activations(X)

        # Final logits (for KD teacher + eval)
        with torch.no_grad():
            final_logits = model.final_logits(activations[-1].detach())
            teacher_probs = F.softmax(final_logits / args.kd_temp, dim=-1)

        # Per-block local CE losses (no inter-block gradient)
        optimizer.zero_grad()
        total_loss = 0.0
        Y_flat = Y.view(-1)

        # Neighbor KL teachers are computed on-the-fly inside the per-block loop
        # (avoid pre-computing all 6 × (B,T,V) tensors which OOM at V=50k)

        for layer_idx in range(cfg.n_layer):
            # Block l: h_l → block → h_{l+1}
            # Only block 0 keeps the graph into the embeddings; deeper inputs are
            # detached so no gradient crosses block boundaries.
            h_l = activations[layer_idx] if layer_idx == 0 else activations[layer_idx].detach()
            h_lp1 = model.blocks[layer_idx](h_l)
            # For dfa_block mode, need h_lp1 with retain_grad to capture block-output-error
            if args.intra_block_method == "dfa_block":
                h_lp1.retain_grad()

            # Path I: progression targets — block l predicts next-(l+1) instead of next-1
            if args.progression_targets and Y_multi is not None:
                Y_block = Y_multi[..., layer_idx]  # (B, T) — block l's specific target
            else:
                Y_block = Y  # default: all blocks predict next-1
            Y_block_flat = Y_block.reshape(-1)

            # Local logits via shared unembedding (exact or BP-free)
            if exit_heads is not None:
                local_z = exit_heads[layer_idx](h_lp1, model.head.weight, Y_block)
            else:
                local_z = model.local_logits(h_lp1, layer_idx)
            local_z_flat = local_z.view(-1, vocab_size)

            # Per-layer weight
            if args.layer_weighting == "linear":
                layer_w = (layer_idx + 1) / cfg.n_layer
            else:
                layer_w = 1.0

            # Ground-truth CE (uses Y_block_flat which respects progression mode)
            loss_gt = F.cross_entropy(local_z_flat, Y_block_flat)

            # KD from final layer (skip when both kd_weight and nbr_weight are 0 to save 3.3GB/block)
            loss_kd = 0.0
            loss_nbr = 0.0
            if args.kd_weight > 0 or args.nbr_weight > 0:
                local_log_probs = F.log_softmax(local_z / args.kd_temp, dim=-1)
                if args.kd_weight > 0:
                    loss_kd = F.kl_div(
                        local_log_probs.view(-1, vocab_size),
                        teacher_probs.view(-1, vocab_size),
                        reduction="batchmean",
                    ) * (args.kd_temp ** 2)
                # Neighbor KL: match next block's prediction (stop-grad), computed on-the-fly
                if args.nbr_weight > 0 and layer_idx < cfg.n_layer - 1:
                    with torch.no_grad():
                        nbr_z = model.local_logits(activations[layer_idx + 2].detach(), layer_idx + 1)
                        nbr_probs = F.softmax(nbr_z / args.kd_temp, dim=-1)
                        del nbr_z
                    loss_nbr = F.kl_div(
                        local_log_probs.view(-1, vocab_size),
                        nbr_probs.view(-1, vocab_size),
                        reduction="batchmean",
                    ) * (args.kd_temp ** 2)
                    del nbr_probs
                del local_log_probs

            # Multi-token aux losses: predict next-2..next-N via aux_heads
            # Each aux head provides an independent gradient direction (different W_k column space).
            # Reuses the same exit_heads[l] (shared codebook) but with different shared_weight + targets.
            loss_aux = 0.0
            if args.n_pred_tokens > 1 and args.aux_weight > 0 and model.aux_heads is not None:
                for k_idx, aux_head in enumerate(model.aux_heads):
                    Y_k = Y_multi[..., k_idx + 1]  # next-(k_idx+2) target
                    if exit_heads is not None:
                        z_k = exit_heads[layer_idx](h_lp1, aux_head.weight, Y_k)
                    else:
                        z_k = F.linear(h_lp1, aux_head.weight)
                    loss_k = F.cross_entropy(z_k.view(-1, vocab_size), Y_k.reshape(-1))
                    loss_aux = loss_aux + loss_k
                loss_aux = loss_aux * args.aux_weight / (args.n_pred_tokens - 1)

            block_loss = layer_w * (
                args.gt_weight * loss_gt
                + args.kd_weight * loss_kd
                + args.nbr_weight * loss_nbr
                + loss_aux
            )
            block_loss.backward()

            # For dfa_block: overwrite intra-block linears' .grad using block-output-error
            if args.intra_block_method == "dfa_block" and h_lp1.grad is not None:
                with torch.no_grad():
                    apply_dfa_block_update(model.blocks[layer_idx], h_lp1.grad)

            total_loss += block_loss.item()

        # Also train head + ln_f via final CE
        h_L_det = activations[-1].detach()
        final_z = model.final_logits(h_L_det)
        head_loss = F.cross_entropy(final_z.view(-1, vocab_size), Y_flat)
        head_loss.backward()
        total_loss += head_loss.item()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # GrAPE: per-step alignment of FA's B toward Jacobian via JVP probes (forward-only)
        if args.fa_grape:
            for module in model.modules():
                if isinstance(module, LocalLinear) and getattr(module, "_fa_grape", False):
                    module.grape_align_step(lr_b=args.fa_grape_lr)

        # Meta-PCN style weight normalization: rescale each LocalLinear's W to have ||W||_2 ~= 1
        # via random matrix theory bound ||W||_2 ~= (sqrt(m) + sqrt(n)) * std(W).
        # Only normalizes LocalLinear W (the trained weight); leaves B (fixed buffer) untouched.
        if args.weight_normalize:
            _weight_normalize(model)

        if it % args.log_interval == 0:
            log({"event": "step", "iter": it, "total_loss": total_loss / (cfg.n_layer + 1),
                 "head_loss": head_loss.item(), "lr": lr})

    if args.save_ckpt:
        ckpt_path = run_dir / "ckpt.pt"
        torch.save({
            "model_state": model.state_dict(),
            "config": vars(cfg),
            "args": vars(args),
            "vocab_size": vocab_size,
        }, ckpt_path)
        log({"event": "save_ckpt", "path": str(ckpt_path)})
        print(f"[{args.run_name}] saved ckpt to {ckpt_path}")


if __name__ == "__main__":
    main()
