summaryrefslogtreecommitdiff
path: root/ep_run/train_local_ce.py
diff options
context:
space:
mode:
authorYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
committerYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
commitb83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
treeb9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/train_local_ce.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/train_local_ce.py')
-rw-r--r--ep_run/train_local_ce.py580
1 files changed, 580 insertions, 0 deletions
diff --git a/ep_run/train_local_ce.py b/ep_run/train_local_ce.py
new file mode 100644
index 0000000..b8f790a
--- /dev/null
+++ b/ep_run/train_local_ce.py
@@ -0,0 +1,580 @@
+"""Local CE exit training — each block gets a vocab-space CE loss via shared unembedding.
+
+Each block l computes:
+ z_l = W_U @ T_l(h_l) (local logits via shared unembedding + optional translator)
+ L_l = λ_gt * CE(z_l, y) + λ_kd * τ² * KL(sg(p_L^τ) || p_l^τ)
+
+Forward weights updated per-block via local CE gradient (intra-block only).
+No inter-block chain rule. Fused attention backward within each block.
+
+This replaces the hidden-space MSE target-matching that failed at scale.
+"""
+import argparse
+import json
+import math
+import pickle
+import time
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from model_local import LocalGPTConfig, LocalBlock, LocalLinear, _make_ln
+from factorized_exit import FactorizedExitHead, ExactParallelExitHead
+from local_layers import initialize_dfa_block_targets, apply_dfa_block_update
+
+
+def get_batch(split, data_dir, block_size, batch_size, device, n_pred=1):
+ fn = "train.bin" if split == "train" else "val.bin"
+ data = np.memmap(data_dir / fn, dtype=np.uint16, mode="r")
+ ix = torch.randint(len(data) - block_size - n_pred, (batch_size,))
+ x = torch.stack([torch.from_numpy(data[i : i + block_size].astype(np.int64)) for i in ix])
+ if n_pred == 1:
+ y = torch.stack([torch.from_numpy(data[i + 1 : i + 1 + block_size].astype(np.int64)) for i in ix])
+ return x.to(device, non_blocking=True), y.to(device, non_blocking=True)
+ # n_pred > 1: targets shape (B, T, n_pred). Y[..., k-1] = next-k target.
+ y_multi = torch.stack([
+ torch.stack([
+ torch.from_numpy(data[i + k : i + k + block_size].astype(np.int64))
+ for k in range(1, n_pred + 1)
+ ], dim=-1)
+ for i in ix
+ ])
+ return x.to(device, non_blocking=True), y_multi.to(device, non_blocking=True)
+
+
+class LowRankTranslator(nn.Module):
+ """T_l(h) = h + A @ B @ h + b. Low-rank affine residual translator."""
+ def __init__(self, d_model, rank=32):
+ super().__init__()
+ self.A = nn.Parameter(torch.zeros(d_model, rank))
+ self.B = nn.Parameter(torch.zeros(rank, d_model))
+ self.bias = nn.Parameter(torch.zeros(d_model))
+ nn.init.normal_(self.A, std=0.01)
+ nn.init.normal_(self.B, std=0.01)
+
+ def forward(self, h):
+ return h + h @ self.B.T @ self.A.T + self.bias
+
+
+class LocalCETransformer(nn.Module):
+ """Transformer with per-block local CE exits via shared unembedding."""
+
+ def __init__(self, config: LocalGPTConfig, translator_rank: int = 0, n_pred_tokens: int = 1,
+ shared_blocks: bool = False):
+ super().__init__()
+ self.config = config
+ self.n_pred_tokens = n_pred_tokens
+ self.shared_blocks = shared_blocks
+ self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
+ self.pos_emb = nn.Embedding(config.block_size, config.n_embd)
+ self.drop = nn.Dropout(config.dropout)
+ if shared_blocks:
+ # Universal Transformer: one block applied n_layer times.
+ # All entries point to the SAME module — gradient accumulates from all "depths".
+ shared = LocalBlock(config)
+ self.blocks = nn.ModuleList([shared for _ in range(config.n_layer)])
+ else:
+ self.blocks = nn.ModuleList([LocalBlock(config) for _ in range(config.n_layer)])
+ self.ln_f = _make_ln(config)
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+ # Auxiliary unembedding heads for next-2..next-N prediction (multi-token training).
+ # Used only as gradient-source heads at training time; inference still uses self.head.
+ if n_pred_tokens > 1:
+ self.aux_heads = nn.ModuleList([
+ nn.Linear(config.n_embd, config.vocab_size, bias=False)
+ for _ in range(n_pred_tokens - 1)
+ ])
+ else:
+ self.aux_heads = None
+
+ # Per-block translators (logit lens = rank 0 = identity)
+ if translator_rank > 0:
+ self.translators = nn.ModuleList([
+ LowRankTranslator(config.n_embd, translator_rank)
+ for _ in range(config.n_layer)
+ ])
+ else:
+ self.translators = None
+
+ self.apply(self._init_weights)
+ for pn, p in self.named_parameters():
+ if pn.endswith("o_proj.weight") or pn.endswith("mlp.proj.weight"):
+ nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Linear, LocalLinear)):
+ nn.init.normal_(m.weight, mean=0.0, std=0.02)
+ if getattr(m, "bias", None) is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Embedding):
+ nn.init.normal_(m.weight, mean=0.0, std=0.02)
+
+ def forward_activations(self, idx):
+ B, T = idx.shape
+ pos = torch.arange(T, device=idx.device)
+ h = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
+ activations = [h]
+ for block in self.blocks:
+ h = block(h)
+ activations.append(h)
+ return activations
+
+ def local_logits(self, h, layer_idx):
+ """h → local logits via optional translator + shared unembedding."""
+ if self.translators is not None:
+ h = self.translators[layer_idx](h)
+ return F.linear(h, self.head.weight) # shared W_U, no separate head
+
+ def final_logits(self, h):
+ return self.head(self.ln_f(h))
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument("--run_name", type=str, required=True)
+ p.add_argument("--seed", type=int, default=1337)
+ p.add_argument("--data_dir", type=str, default="data/shakespeare_char")
+ p.add_argument("--out_dir", type=str, default="runs_local")
+ p.add_argument("--block_size", type=int, default=256)
+ p.add_argument("--batch_size", type=int, default=64)
+ p.add_argument("--n_layer", type=int, default=6)
+ p.add_argument("--n_head", type=int, default=6)
+ p.add_argument("--n_embd", type=int, default=384)
+ p.add_argument("--dropout", type=float, default=0.2)
+ p.add_argument("--max_iters", type=int, default=5000)
+ p.add_argument("--warmup_iters", type=int, default=100)
+ p.add_argument("--max_lr", type=float, default=1e-3)
+ p.add_argument("--min_lr", type=float, default=1e-4)
+ p.add_argument("--eval_interval", type=int, default=250)
+ p.add_argument("--eval_iters", type=int, default=100)
+ p.add_argument("--log_interval", type=int, default=50)
+ p.add_argument("--attn_mode", type=str, default="softmax")
+ p.add_argument("--translator_rank", type=int, default=0, help="0=identity (logit lens), >0=low-rank affine")
+ p.add_argument("--kd_weight", type=float, default=1.0, help="weight for KL distillation from final layer")
+ p.add_argument("--kd_temp", type=float, default=2.0, help="temperature for KD")
+ p.add_argument("--gt_weight", type=float, default=1.0, help="weight for ground-truth CE")
+ p.add_argument("--nbr_weight", type=float, default=0.0, help="weight for neighbor KL (sg(p_{l+1}) || p_l)")
+ p.add_argument("--layer_weighting", type=str, default="uniform", choices=["uniform", "linear"],
+ help="per-layer loss weight: uniform=all 1.0, linear=l/L")
+ p.add_argument("--bp_free_exit", type=str, default="none",
+ choices=["none", "dense", "hybrid", "parallel_only", "parallel_gold", "parallel_topmass"],
+ help="BP-free exit: none=W_U^T, dense/hybrid=compressor, parallel_*=exact parallel term")
+ p.add_argument("--exit_rank", type=int, default=128, help="rank for BP-free exit compressor")
+ p.add_argument("--exit_rank_exact", type=int, default=32, help="exact rank for hybrid compressor")
+ p.add_argument("--exit_topk", type=int, default=8, help="top-k for hybrid compressor")
+ p.add_argument("--exit_residual_rank", type=int, default=32,
+ help="residual_rank for ExactParallelExitHead (parallel_gold/topmass): code dim for h-perp residual")
+ p.add_argument("--intra_block_method", type=str, default="bp", choices=["bp", "fa", "sign_sym", "dfa_block"],
+ help="intra-block: bp=W^T, fa=seq random B, sign_sym=sign(W)·rescale, dfa_block=direct from block-output-error")
+ p.add_argument("--mlp_topk", type=int, default=0,
+ help="if >0, apply hard k-WTA to MLP hidden activation (4*n_embd dim)")
+ p.add_argument("--resid_topk", type=int, default=0,
+ help="if >0, apply hard k-WTA to residual stream output of each block (n_embd dim)")
+ p.add_argument("--vq_codes", type=int, default=0,
+ help="if >0, apply directional VQ to residual stream at each block (K codebook entries, frozen)")
+ p.add_argument("--subspace_rank", type=int, default=0,
+ help="if >0, project residual stream to fixed r-dim orthonormal subspace at each block")
+ p.add_argument("--subspace_per_layer", action="store_true",
+ help="use DIFFERENT random Q per layer (ablation: tests if shared Q is necessary)")
+ p.add_argument("--fa_init_sign", action="store_true",
+ help="init FA's fixed B as sign(W_init)*rescale instead of random (frozen sign_sym)")
+ p.add_argument("--shared_blocks", action="store_true",
+ help="Universal Transformer: all blocks share the same parameters (single block applied n_layer times)")
+ p.add_argument("--fa_init", type=str, default="gaussian",
+ choices=["gaussian", "orthogonal", "ortho_he", "sparse"],
+ help="FA's fixed B init mode (gaussian=Lillicrap, orthogonal=JL-isometric, ortho_he=He-init backward, sparse=structured)")
+ p.add_argument("--fa_sparse_k", type=int, default=0,
+ help="for fa_init=sparse: non-zero entries per row (0 = auto = in_features/16)")
+ p.add_argument("--gated_blocks", action="store_true",
+ help="Path IV: learned per-block residual gates (α_attn, α_mlp). Lets useless layers self-deactivate.")
+ p.add_argument("--progression_targets", action="store_true",
+ help="Path I: each block l predicts next-(l+1) token (progressive prediction horizons per layer)")
+ p.add_argument("--weight_normalize", action="store_true",
+ help="Meta-PCN style WN: after each optimizer step, normalize LocalLinear's W by (sqrt(m)+sqrt(n))*std(W) to keep ||W||_2 ~= 1")
+ p.add_argument("--pc_inference", type=int, default=0,
+ help="Predictive coding inference steps T (T=0 disables PC mode, uses standard local CE)")
+ p.add_argument("--pc_inference_lr", type=float, default=0.1,
+ help="Inference step size η for PC z updates")
+ p.add_argument("--pc_top_weight", type=float, default=1.0,
+ help="Weight of top-down CE term in PC energy F")
+ p.add_argument("--fa_grape", action="store_true",
+ help="GrAPE: per-step JVP-based cosine alignment of FA's B toward true Jacobian (Caillon et al. 2026)")
+ p.add_argument("--fa_grape_lr", type=float, default=0.01,
+ help="Learning rate for GrAPE B alignment update")
+ p.add_argument("--fa_grape_n_probe", type=int, default=32,
+ help="Number of probe samples for JVP rank-1 Jacobian estimate")
+ p.add_argument("--save_ckpt", action="store_true",
+ help="save final model state to run_dir/ckpt.pt for downstream probing")
+ p.add_argument("--n_pred_tokens", type=int, default=1,
+ help="multi-token prediction: predict next-1..next-N (N=1 disables, default)")
+ p.add_argument("--aux_weight", type=float, default=0.3,
+ help="weight for aux next-k losses (k=2..N). Primary next-1 always weight 1.0.")
+ args = p.parse_args()
+
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ data_dir = Path(args.data_dir)
+ with open(data_dir / "meta.pkl", "rb") as f:
+ meta = pickle.load(f)
+ vocab_size = meta["vocab_size"]
+
+ run_dir = Path(args.out_dir) / args.run_name
+ run_dir.mkdir(parents=True, exist_ok=True)
+ log_path = run_dir / "log.jsonl"
+ log_path.write_text("")
+
+ cfg = LocalGPTConfig(
+ block_size=args.block_size, vocab_size=vocab_size,
+ n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd,
+ dropout=args.dropout, attn_mode=args.attn_mode,
+ method=args.intra_block_method,
+ fuse_attn_local=True, ste_gelu=True, ln_mode="center_scale",
+ mlp_topk=args.mlp_topk, resid_topk=args.resid_topk,
+ vq_codes=args.vq_codes, subspace_rank=args.subspace_rank,
+ fa_init_mode=args.fa_init, fa_sparse_k=args.fa_sparse_k,
+ gated_blocks=args.gated_blocks,
+ fa_grape=args.fa_grape, fa_grape_n_probe=args.fa_grape_n_probe,
+ )
+ model = LocalCETransformer(cfg, translator_rank=args.translator_rank,
+ n_pred_tokens=args.n_pred_tokens,
+ shared_blocks=args.shared_blocks).to(device)
+
+ # Frozen sign_sym: replace FA's random B with sign(W_init)*rescale, then freeze.
+ # B is still a fixed buffer (BP-free by definition B), just structured init.
+ if args.fa_init_sign and args.intra_block_method == "fa":
+ with torch.no_grad():
+ for module in model.modules():
+ if isinstance(module, LocalLinear) and module.method == "fa":
+ scale = module.weight.norm() / (module.weight.numel() ** 0.5 + 1e-8)
+ module.B.copy_(torch.sign(module.weight) * scale)
+
+ # Per-layer different Q ablation: replace each block's shared-seed subspace
+ # with independently-seeded subspace (tests if shared Q is the mechanism)
+ if args.subspace_per_layer and args.subspace_rank > 0:
+ from model_local import FrozenSubspace
+ for i, block in enumerate(model.blocks):
+ block.subspace = FrozenSubspace(args.n_embd, args.subspace_rank, seed=1000 + i).to(device)
+
+ # Initialize DFA-block targets if needed
+ if args.intra_block_method == "dfa_block":
+ initialize_dfa_block_targets(model, args.n_embd)
+
+ # BP-free exit heads (one per block)
+ exit_heads = None
+ if args.bp_free_exit in ("dense", "hybrid"):
+ exit_heads = nn.ModuleList([
+ FactorizedExitHead(
+ args.n_embd, vocab_size, mode=args.bp_free_exit,
+ rank=args.exit_rank, rank_exact=args.exit_rank_exact, topk=args.exit_topk,
+ ) for _ in range(cfg.n_layer)
+ ]).to(device)
+ elif args.bp_free_exit.startswith("parallel"):
+ exit_heads = nn.ModuleList([
+ ExactParallelExitHead(
+ args.n_embd, vocab_size, mode=args.bp_free_exit,
+ residual_rank=args.exit_residual_rank,
+ ) for _ in range(cfg.n_layer)
+ ]).to(device)
+
+ n_params = sum(p.numel() for p in model.parameters())
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.max_lr, weight_decay=0.1)
+
+ t0 = time.time()
+
+ def log(rec):
+ rec["t"] = time.time() - t0
+ with open(log_path, "a") as f:
+ f.write(json.dumps(rec) + "\n")
+
+ log({"event": "start", "method": "local_ce", "params": n_params,
+ "translator_rank": args.translator_rank, "config": vars(args)})
+ print(f"[{args.run_name}] local_ce, params={n_params/1e6:.2f}M, translator_rank={args.translator_rank}")
+
+ def lr_schedule(it):
+ if it < args.warmup_iters:
+ return args.max_lr * (it + 1) / (args.warmup_iters + 1)
+ decay = 0.5 * (1 + math.cos(math.pi * (it - args.warmup_iters) /
+ max(1, args.max_iters - args.warmup_iters)))
+ return args.min_lr + decay * (args.max_lr - args.min_lr)
+
+ @torch.no_grad()
+ def eval_loss():
+ model.eval()
+ losses = torch.zeros(args.eval_iters)
+ for k in range(args.eval_iters):
+ X, Y = get_batch("val", data_dir, args.block_size, args.batch_size, device)
+ acts = model.forward_activations(X)
+ logits = model.final_logits(acts[-1])
+ loss = F.cross_entropy(logits.view(-1, vocab_size), Y.view(-1))
+ losses[k] = loss.item()
+ model.train()
+ return losses.mean().item()
+
+ model.train()
+ for it in range(args.max_iters + 1):
+ lr = lr_schedule(it)
+ for g in optimizer.param_groups:
+ g["lr"] = lr
+
+ if it % args.eval_interval == 0 or it == args.max_iters:
+ val = eval_loss()
+ log({"event": "eval", "iter": it, "val_loss": val, "lr": lr})
+ print(f"[{args.run_name}] iter {it:5d} val {val:.4f} lr {lr:.4g}")
+
+ if it == args.max_iters:
+ break
+
+ # Determine n_pred for batch fetch
+ # - n_pred_tokens > 1: multi-token MTP aux losses (each block predicts N targets via N heads)
+ # - progression_targets: each block l predicts next-(l+1) (so need n_pred = n_layer)
+ if args.n_pred_tokens > 1:
+ n_pred = args.n_pred_tokens
+ elif args.progression_targets:
+ n_pred = cfg.n_layer
+ else:
+ n_pred = 1
+
+ if n_pred > 1:
+ X, Y_multi = get_batch("train", data_dir, args.block_size, args.batch_size, device,
+ n_pred=n_pred)
+ Y = Y_multi[..., 0] # (B, T) — next-1 target for default
+ else:
+ X, Y = get_batch("train", data_dir, args.block_size, args.batch_size, device)
+ Y_multi = None
+
+ # ============================================================
+ # PC mode: predictive coding inference + Hebbian-style updates
+ # (when --pc_inference T > 0, replaces standard local CE)
+ # ============================================================
+ if args.pc_inference > 0:
+ optimizer.zero_grad()
+ Y_flat = Y.view(-1)
+
+ # 1. Forward init (no autograd graph during init)
+ with torch.no_grad():
+ init_acts = model.forward_activations(X)
+ # z[0] = embedding, clamped (no grad). z[1..L] = block outputs, evolve.
+ z = [init_acts[0].detach()]
+ for l in range(1, len(init_acts)):
+ z.append(init_acts[l].detach().clone().requires_grad_(True))
+
+ # 2. PC inference: T iterations of z updates via ∂F/∂z
+ # F = Σ_{l<L} (1/2) ||z_l - block_{l-1}(z_{l-1})||² / d + λ·CE(W_U @ z_L, y)
+ # Skip PE_L (Meta-PCN trick: CE replaces last-layer squared error)
+ # Use mean over hidden dim (per-token PE²) for scale-invariance.
+ for t in range(args.pc_inference):
+ F_energy = 0.0
+ for l in range(1, len(z) - 1): # l = 1..L-1, skip PE_L
+ z_hat = model.blocks[l - 1](z[l - 1])
+ pe = z[l] - z_hat
+ F_energy = F_energy + 0.5 * (pe ** 2).mean() # mean over (B,T,d) → scale-invariant
+ # Top-down: CE at z_L (replaces PE_L per Meta-PCN convention)
+ logits_top = F.linear(z[-1], model.head.weight)
+ CE_top = F.cross_entropy(logits_top.view(-1, vocab_size), Y_flat)
+ F_total = F_energy + args.pc_top_weight * CE_top
+ # Compute ∂F/∂z[1..L] (FA-flavored due to LocalLinear FA backward inside blocks)
+ grads = torch.autograd.grad(F_total, z[1:], create_graph=False, retain_graph=False)
+ # SGD update on z's
+ with torch.no_grad():
+ new_z = [z[0]]
+ for i, g in enumerate(grads):
+ new_z.append((z[i + 1] - args.pc_inference_lr * g).detach().requires_grad_(True))
+ z = new_z
+
+ # 3. Weight update via per-block PE loss using converged z's
+ # For block l-1: minimize ||sg(z_l) - block_{l-1}(sg(z_{l-1}))||²
+ # backward gives FA-flavored W gradients (Hebbian-equivalent at equilibrium)
+ total_loss = 0.0
+ for l in range(1, len(z)):
+ z_hat = model.blocks[l - 1](z[l - 1].detach())
+ target = z[l].detach()
+ pe_loss = 0.5 * ((target - z_hat) ** 2).mean()
+ pe_loss.backward()
+ total_loss += pe_loss.item()
+
+ # Final head + ln_f via CE on converged z[-1]
+ final_z = model.final_logits(z[-1].detach())
+ head_loss = F.cross_entropy(final_z.view(-1, vocab_size), Y_flat)
+ head_loss.backward()
+ total_loss += head_loss.item()
+
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ optimizer.step()
+
+ # Optional WN
+ if args.weight_normalize:
+ with torch.no_grad():
+ for module in model.modules():
+ if isinstance(module, LocalLinear):
+ m, n = module.weight.shape
+ sigma_w = module.weight.std()
+ scale = (m ** 0.5 + n ** 0.5) * sigma_w
+ if scale > 1e-8:
+ module.weight.div_(scale)
+
+ if it % args.log_interval == 0:
+ log({"event": "step", "iter": it,
+ "total_loss": total_loss / (cfg.n_layer + 1),
+ "head_loss": head_loss.item(), "lr": lr})
+ continue
+ # ============================================================
+ # End PC mode; standard local CE follows
+ # ============================================================
+
+ # Forward: compute all activations
+ activations = model.forward_activations(X)
+
+ # Final logits (for KD teacher + eval)
+ with torch.no_grad():
+ final_logits = model.final_logits(activations[-1].detach())
+ teacher_probs = F.softmax(final_logits / args.kd_temp, dim=-1)
+
+ # Per-block local CE losses (no inter-block gradient)
+ optimizer.zero_grad()
+ total_loss = 0.0
+ Y_flat = Y.view(-1)
+
+ # Neighbor KL teachers are computed on-the-fly inside the per-block loop
+ # (avoid pre-computing all 6 × (B,T,V) tensors which OOM at V=50k)
+
+ for l in range(cfg.n_layer):
+ # Block l: h_l → block → h_{l+1}
+ h_l = activations[l] if l == 0 else activations[l].detach()
+ # For dfa_block mode, need h_lp1 with retain_grad to capture block-output-error
+ if args.intra_block_method == "dfa_block":
+ h_lp1 = model.blocks[l](h_l)
+ h_lp1.retain_grad()
+ else:
+ h_lp1 = model.blocks[l](h_l)
+
+ # Path I: progression targets — block l predicts next-(l+1) instead of next-1
+ if args.progression_targets and Y_multi is not None:
+ Y_block = Y_multi[..., l] # (B, T) — block l's specific target
+ else:
+ Y_block = Y # default: all blocks predict next-1
+ Y_block_flat = Y_block.reshape(-1)
+
+ # Local logits via shared unembedding (exact or BP-free)
+ if exit_heads is not None:
+ local_z = exit_heads[l](h_lp1, model.head.weight, Y_block)
+ else:
+ local_z = model.local_logits(h_lp1, l)
+ local_z_flat = local_z.view(-1, vocab_size)
+
+ # Per-layer weight
+ if args.layer_weighting == "linear":
+ layer_w = (l + 1) / cfg.n_layer
+ else:
+ layer_w = 1.0
+
+ # Ground-truth CE (uses Y_block_flat which respects progression mode)
+ loss_gt = F.cross_entropy(local_z_flat, Y_block_flat)
+
+ # KD from final layer (skip when both kd_weight and nbr_weight are 0 to save 3.3GB/block)
+ loss_kd = 0.0
+ loss_nbr = 0.0
+ if args.kd_weight > 0 or args.nbr_weight > 0:
+ local_log_probs = F.log_softmax(local_z / args.kd_temp, dim=-1)
+ if args.kd_weight > 0:
+ loss_kd = F.kl_div(
+ local_log_probs.view(-1, vocab_size),
+ teacher_probs.view(-1, vocab_size),
+ reduction="batchmean",
+ ) * (args.kd_temp ** 2)
+ # Neighbor KL: match next block's prediction (stop-grad), computed on-the-fly
+ if args.nbr_weight > 0 and l < cfg.n_layer - 1:
+ with torch.no_grad():
+ nbr_z = model.local_logits(activations[l + 2].detach(), l + 1)
+ nbr_probs = F.softmax(nbr_z / args.kd_temp, dim=-1)
+ del nbr_z
+ loss_nbr = F.kl_div(
+ local_log_probs.view(-1, vocab_size),
+ nbr_probs.view(-1, vocab_size),
+ reduction="batchmean",
+ ) * (args.kd_temp ** 2)
+ del nbr_probs
+ del local_log_probs
+
+ # Multi-token aux losses: predict next-2..next-N via aux_heads
+ # Each aux head provides an independent gradient direction (different W_k column space).
+ # Reuses the same exit_heads[l] (shared codebook) but with different shared_weight + targets.
+ loss_aux = 0.0
+ if args.n_pred_tokens > 1 and args.aux_weight > 0 and model.aux_heads is not None:
+ for k_idx, aux_head in enumerate(model.aux_heads):
+ Y_k = Y_multi[..., k_idx + 1] # next-(k_idx+2) target
+ if exit_heads is not None:
+ z_k = exit_heads[l](h_lp1, aux_head.weight, Y_k)
+ else:
+ z_k = F.linear(h_lp1, aux_head.weight)
+ loss_k = F.cross_entropy(z_k.view(-1, vocab_size), Y_k.reshape(-1))
+ loss_aux = loss_aux + loss_k
+ loss_aux = loss_aux * args.aux_weight / (args.n_pred_tokens - 1)
+
+ block_loss = layer_w * (
+ args.gt_weight * loss_gt
+ + args.kd_weight * loss_kd
+ + args.nbr_weight * loss_nbr
+ + loss_aux
+ )
+ block_loss.backward()
+
+ # For dfa_block: overwrite intra-block linears' .grad using block-output-error
+ if args.intra_block_method == "dfa_block" and h_lp1.grad is not None:
+ with torch.no_grad():
+ apply_dfa_block_update(model.blocks[l], h_lp1.grad)
+
+ total_loss += block_loss.item()
+
+ # Also train head + ln_f via final CE
+ h_L_det = activations[-1].detach()
+ final_z = model.final_logits(h_L_det)
+ head_loss = F.cross_entropy(final_z.view(-1, vocab_size), Y_flat)
+ head_loss.backward()
+ total_loss += head_loss.item()
+
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ optimizer.step()
+
+ # GrAPE: per-step alignment of FA's B toward Jacobian via JVP probes (forward-only)
+ if args.fa_grape:
+ for module in model.modules():
+ if isinstance(module, LocalLinear) and getattr(module, "_fa_grape", False):
+ module.grape_align_step(lr_b=args.fa_grape_lr)
+
+ # Meta-PCN style weight normalization: rescale each LocalLinear's W to have ||W||_2 ~= 1
+ # via random matrix theory bound ||W||_2 ~= (sqrt(m) + sqrt(n)) * std(W).
+ # Only normalizes LocalLinear W (the trained weight); leaves B (fixed buffer) untouched.
+ if args.weight_normalize:
+ with torch.no_grad():
+ for module in model.modules():
+ if isinstance(module, LocalLinear):
+ m, n = module.weight.shape
+ sigma_w = module.weight.std()
+ scale = (m ** 0.5 + n ** 0.5) * sigma_w
+ if scale > 1e-8:
+ module.weight.div_(scale)
+
+ if it % args.log_interval == 0:
+ log({"event": "step", "iter": it, "total_loss": total_loss / (cfg.n_layer + 1),
+ "head_loss": head_loss.item(), "lr": lr})
+
+ if args.save_ckpt:
+ ckpt_path = run_dir / "ckpt.pt"
+ torch.save({
+ "model_state": model.state_dict(),
+ "config": vars(cfg),
+ "args": vars(args),
+ "vocab_size": vocab_size,
+ }, ckpt_path)
+ log({"event": "save_ckpt", "path": str(ckpt_path)})
+ print(f"[{args.run_name}] saved ckpt to {ckpt_path}")
+
+
+if __name__ == "__main__":
+ main()