"""Local CE exit training — each block gets a vocab-space CE loss via shared unembedding. Each block l computes: z_l = W_U @ T_l(h_l) (local logits via shared unembedding + optional translator) L_l = λ_gt * CE(z_l, y) + λ_kd * τ² * KL(sg(p_L^τ) || p_l^τ) Forward weights updated per-block via local CE gradient (intra-block only). No inter-block chain rule. Fused attention backward within each block. This replaces the hidden-space MSE target-matching that failed at scale. """ import argparse import json import math import pickle import time from pathlib import Path import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from model_local import LocalGPTConfig, LocalBlock, LocalLinear, _make_ln from factorized_exit import FactorizedExitHead, ExactParallelExitHead from local_layers import initialize_dfa_block_targets, apply_dfa_block_update def get_batch(split, data_dir, block_size, batch_size, device, n_pred=1): fn = "train.bin" if split == "train" else "val.bin" data = np.memmap(data_dir / fn, dtype=np.uint16, mode="r") ix = torch.randint(len(data) - block_size - n_pred, (batch_size,)) x = torch.stack([torch.from_numpy(data[i : i + block_size].astype(np.int64)) for i in ix]) if n_pred == 1: y = torch.stack([torch.from_numpy(data[i + 1 : i + 1 + block_size].astype(np.int64)) for i in ix]) return x.to(device, non_blocking=True), y.to(device, non_blocking=True) # n_pred > 1: targets shape (B, T, n_pred). Y[..., k-1] = next-k target. y_multi = torch.stack([ torch.stack([ torch.from_numpy(data[i + k : i + k + block_size].astype(np.int64)) for k in range(1, n_pred + 1) ], dim=-1) for i in ix ]) return x.to(device, non_blocking=True), y_multi.to(device, non_blocking=True) class LowRankTranslator(nn.Module): """T_l(h) = h + A @ B @ h + b. Low-rank affine residual translator.""" def __init__(self, d_model, rank=32): super().__init__() self.A = nn.Parameter(torch.zeros(d_model, rank)) self.B = nn.Parameter(torch.zeros(rank, d_model)) self.bias = nn.Parameter(torch.zeros(d_model)) nn.init.normal_(self.A, std=0.01) nn.init.normal_(self.B, std=0.01) def forward(self, h): return h + h @ self.B.T @ self.A.T + self.bias class LocalCETransformer(nn.Module): """Transformer with per-block local CE exits via shared unembedding.""" def __init__(self, config: LocalGPTConfig, translator_rank: int = 0, n_pred_tokens: int = 1, shared_blocks: bool = False): super().__init__() self.config = config self.n_pred_tokens = n_pred_tokens self.shared_blocks = shared_blocks self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) self.pos_emb = nn.Embedding(config.block_size, config.n_embd) self.drop = nn.Dropout(config.dropout) if shared_blocks: # Universal Transformer: one block applied n_layer times. # All entries point to the SAME module — gradient accumulates from all "depths". shared = LocalBlock(config) self.blocks = nn.ModuleList([shared for _ in range(config.n_layer)]) else: self.blocks = nn.ModuleList([LocalBlock(config) for _ in range(config.n_layer)]) self.ln_f = _make_ln(config) self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # Auxiliary unembedding heads for next-2..next-N prediction (multi-token training). # Used only as gradient-source heads at training time; inference still uses self.head. if n_pred_tokens > 1: self.aux_heads = nn.ModuleList([ nn.Linear(config.n_embd, config.vocab_size, bias=False) for _ in range(n_pred_tokens - 1) ]) else: self.aux_heads = None # Per-block translators (logit lens = rank 0 = identity) if translator_rank > 0: self.translators = nn.ModuleList([ LowRankTranslator(config.n_embd, translator_rank) for _ in range(config.n_layer) ]) else: self.translators = None self.apply(self._init_weights) for pn, p in self.named_parameters(): if pn.endswith("o_proj.weight") or pn.endswith("mlp.proj.weight"): nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) def _init_weights(self, m): if isinstance(m, (nn.Linear, LocalLinear)): nn.init.normal_(m.weight, mean=0.0, std=0.02) if getattr(m, "bias", None) is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Embedding): nn.init.normal_(m.weight, mean=0.0, std=0.02) def forward_activations(self, idx): B, T = idx.shape pos = torch.arange(T, device=idx.device) h = self.drop(self.tok_emb(idx) + self.pos_emb(pos)) activations = [h] for block in self.blocks: h = block(h) activations.append(h) return activations def local_logits(self, h, layer_idx): """h → local logits via optional translator + shared unembedding.""" if self.translators is not None: h = self.translators[layer_idx](h) return F.linear(h, self.head.weight) # shared W_U, no separate head def final_logits(self, h): return self.head(self.ln_f(h)) def main(): p = argparse.ArgumentParser() p.add_argument("--run_name", type=str, required=True) p.add_argument("--seed", type=int, default=1337) p.add_argument("--data_dir", type=str, default="data/shakespeare_char") p.add_argument("--out_dir", type=str, default="runs_local") p.add_argument("--block_size", type=int, default=256) p.add_argument("--batch_size", type=int, default=64) p.add_argument("--n_layer", type=int, default=6) p.add_argument("--n_head", type=int, default=6) p.add_argument("--n_embd", type=int, default=384) p.add_argument("--dropout", type=float, default=0.2) p.add_argument("--max_iters", type=int, default=5000) p.add_argument("--warmup_iters", type=int, default=100) p.add_argument("--max_lr", type=float, default=1e-3) p.add_argument("--min_lr", type=float, default=1e-4) p.add_argument("--eval_interval", type=int, default=250) p.add_argument("--eval_iters", type=int, default=100) p.add_argument("--log_interval", type=int, default=50) p.add_argument("--attn_mode", type=str, default="softmax") p.add_argument("--translator_rank", type=int, default=0, help="0=identity (logit lens), >0=low-rank affine") p.add_argument("--kd_weight", type=float, default=1.0, help="weight for KL distillation from final layer") p.add_argument("--kd_temp", type=float, default=2.0, help="temperature for KD") p.add_argument("--gt_weight", type=float, default=1.0, help="weight for ground-truth CE") p.add_argument("--nbr_weight", type=float, default=0.0, help="weight for neighbor KL (sg(p_{l+1}) || p_l)") p.add_argument("--layer_weighting", type=str, default="uniform", choices=["uniform", "linear"], help="per-layer loss weight: uniform=all 1.0, linear=l/L") p.add_argument("--bp_free_exit", type=str, default="none", choices=["none", "dense", "hybrid", "parallel_only", "parallel_gold", "parallel_topmass"], help="BP-free exit: none=W_U^T, dense/hybrid=compressor, parallel_*=exact parallel term") p.add_argument("--exit_rank", type=int, default=128, help="rank for BP-free exit compressor") p.add_argument("--exit_rank_exact", type=int, default=32, help="exact rank for hybrid compressor") p.add_argument("--exit_topk", type=int, default=8, help="top-k for hybrid compressor") p.add_argument("--exit_residual_rank", type=int, default=32, help="residual_rank for ExactParallelExitHead (parallel_gold/topmass): code dim for h-perp residual") p.add_argument("--intra_block_method", type=str, default="bp", choices=["bp", "fa", "sign_sym", "dfa_block"], help="intra-block: bp=W^T, fa=seq random B, sign_sym=sign(W)·rescale, dfa_block=direct from block-output-error") p.add_argument("--mlp_topk", type=int, default=0, help="if >0, apply hard k-WTA to MLP hidden activation (4*n_embd dim)") p.add_argument("--resid_topk", type=int, default=0, help="if >0, apply hard k-WTA to residual stream output of each block (n_embd dim)") p.add_argument("--vq_codes", type=int, default=0, help="if >0, apply directional VQ to residual stream at each block (K codebook entries, frozen)") p.add_argument("--subspace_rank", type=int, default=0, help="if >0, project residual stream to fixed r-dim orthonormal subspace at each block") p.add_argument("--subspace_per_layer", action="store_true", help="use DIFFERENT random Q per layer (ablation: tests if shared Q is necessary)") p.add_argument("--fa_init_sign", action="store_true", help="init FA's fixed B as sign(W_init)*rescale instead of random (frozen sign_sym)") p.add_argument("--shared_blocks", action="store_true", help="Universal Transformer: all blocks share the same parameters (single block applied n_layer times)") p.add_argument("--fa_init", type=str, default="gaussian", choices=["gaussian", "orthogonal", "ortho_he", "sparse"], help="FA's fixed B init mode (gaussian=Lillicrap, orthogonal=JL-isometric, ortho_he=He-init backward, sparse=structured)") p.add_argument("--fa_sparse_k", type=int, default=0, help="for fa_init=sparse: non-zero entries per row (0 = auto = in_features/16)") p.add_argument("--gated_blocks", action="store_true", help="Path IV: learned per-block residual gates (α_attn, α_mlp). Lets useless layers self-deactivate.") p.add_argument("--progression_targets", action="store_true", help="Path I: each block l predicts next-(l+1) token (progressive prediction horizons per layer)") p.add_argument("--weight_normalize", action="store_true", help="Meta-PCN style WN: after each optimizer step, normalize LocalLinear's W by (sqrt(m)+sqrt(n))*std(W) to keep ||W||_2 ~= 1") p.add_argument("--pc_inference", type=int, default=0, help="Predictive coding inference steps T (T=0 disables PC mode, uses standard local CE)") p.add_argument("--pc_inference_lr", type=float, default=0.1, help="Inference step size η for PC z updates") p.add_argument("--pc_top_weight", type=float, default=1.0, help="Weight of top-down CE term in PC energy F") p.add_argument("--fa_grape", action="store_true", help="GrAPE: per-step JVP-based cosine alignment of FA's B toward true Jacobian (Caillon et al. 2026)") p.add_argument("--fa_grape_lr", type=float, default=0.01, help="Learning rate for GrAPE B alignment update") p.add_argument("--fa_grape_n_probe", type=int, default=32, help="Number of probe samples for JVP rank-1 Jacobian estimate") p.add_argument("--save_ckpt", action="store_true", help="save final model state to run_dir/ckpt.pt for downstream probing") p.add_argument("--n_pred_tokens", type=int, default=1, help="multi-token prediction: predict next-1..next-N (N=1 disables, default)") p.add_argument("--aux_weight", type=float, default=0.3, help="weight for aux next-k losses (k=2..N). Primary next-1 always weight 1.0.") args = p.parse_args() torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) device = "cuda" if torch.cuda.is_available() else "cpu" data_dir = Path(args.data_dir) with open(data_dir / "meta.pkl", "rb") as f: meta = pickle.load(f) vocab_size = meta["vocab_size"] run_dir = Path(args.out_dir) / args.run_name run_dir.mkdir(parents=True, exist_ok=True) log_path = run_dir / "log.jsonl" log_path.write_text("") cfg = LocalGPTConfig( block_size=args.block_size, vocab_size=vocab_size, n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd, dropout=args.dropout, attn_mode=args.attn_mode, method=args.intra_block_method, fuse_attn_local=True, ste_gelu=True, ln_mode="center_scale", mlp_topk=args.mlp_topk, resid_topk=args.resid_topk, vq_codes=args.vq_codes, subspace_rank=args.subspace_rank, fa_init_mode=args.fa_init, fa_sparse_k=args.fa_sparse_k, gated_blocks=args.gated_blocks, fa_grape=args.fa_grape, fa_grape_n_probe=args.fa_grape_n_probe, ) model = LocalCETransformer(cfg, translator_rank=args.translator_rank, n_pred_tokens=args.n_pred_tokens, shared_blocks=args.shared_blocks).to(device) # Frozen sign_sym: replace FA's random B with sign(W_init)*rescale, then freeze. # B is still a fixed buffer (BP-free by definition B), just structured init. if args.fa_init_sign and args.intra_block_method == "fa": with torch.no_grad(): for module in model.modules(): if isinstance(module, LocalLinear) and module.method == "fa": scale = module.weight.norm() / (module.weight.numel() ** 0.5 + 1e-8) module.B.copy_(torch.sign(module.weight) * scale) # Per-layer different Q ablation: replace each block's shared-seed subspace # with independently-seeded subspace (tests if shared Q is the mechanism) if args.subspace_per_layer and args.subspace_rank > 0: from model_local import FrozenSubspace for i, block in enumerate(model.blocks): block.subspace = FrozenSubspace(args.n_embd, args.subspace_rank, seed=1000 + i).to(device) # Initialize DFA-block targets if needed if args.intra_block_method == "dfa_block": initialize_dfa_block_targets(model, args.n_embd) # BP-free exit heads (one per block) exit_heads = None if args.bp_free_exit in ("dense", "hybrid"): exit_heads = nn.ModuleList([ FactorizedExitHead( args.n_embd, vocab_size, mode=args.bp_free_exit, rank=args.exit_rank, rank_exact=args.exit_rank_exact, topk=args.exit_topk, ) for _ in range(cfg.n_layer) ]).to(device) elif args.bp_free_exit.startswith("parallel"): exit_heads = nn.ModuleList([ ExactParallelExitHead( args.n_embd, vocab_size, mode=args.bp_free_exit, residual_rank=args.exit_residual_rank, ) for _ in range(cfg.n_layer) ]).to(device) n_params = sum(p.numel() for p in model.parameters()) optimizer = torch.optim.AdamW(model.parameters(), lr=args.max_lr, weight_decay=0.1) t0 = time.time() def log(rec): rec["t"] = time.time() - t0 with open(log_path, "a") as f: f.write(json.dumps(rec) + "\n") log({"event": "start", "method": "local_ce", "params": n_params, "translator_rank": args.translator_rank, "config": vars(args)}) print(f"[{args.run_name}] local_ce, params={n_params/1e6:.2f}M, translator_rank={args.translator_rank}") def lr_schedule(it): if it < args.warmup_iters: return args.max_lr * (it + 1) / (args.warmup_iters + 1) decay = 0.5 * (1 + math.cos(math.pi * (it - args.warmup_iters) / max(1, args.max_iters - args.warmup_iters))) return args.min_lr + decay * (args.max_lr - args.min_lr) @torch.no_grad() def eval_loss(): model.eval() losses = torch.zeros(args.eval_iters) for k in range(args.eval_iters): X, Y = get_batch("val", data_dir, args.block_size, args.batch_size, device) acts = model.forward_activations(X) logits = model.final_logits(acts[-1]) loss = F.cross_entropy(logits.view(-1, vocab_size), Y.view(-1)) losses[k] = loss.item() model.train() return losses.mean().item() model.train() for it in range(args.max_iters + 1): lr = lr_schedule(it) for g in optimizer.param_groups: g["lr"] = lr if it % args.eval_interval == 0 or it == args.max_iters: val = eval_loss() log({"event": "eval", "iter": it, "val_loss": val, "lr": lr}) print(f"[{args.run_name}] iter {it:5d} val {val:.4f} lr {lr:.4g}") if it == args.max_iters: break # Determine n_pred for batch fetch # - n_pred_tokens > 1: multi-token MTP aux losses (each block predicts N targets via N heads) # - progression_targets: each block l predicts next-(l+1) (so need n_pred = n_layer) if args.n_pred_tokens > 1: n_pred = args.n_pred_tokens elif args.progression_targets: n_pred = cfg.n_layer else: n_pred = 1 if n_pred > 1: X, Y_multi = get_batch("train", data_dir, args.block_size, args.batch_size, device, n_pred=n_pred) Y = Y_multi[..., 0] # (B, T) — next-1 target for default else: X, Y = get_batch("train", data_dir, args.block_size, args.batch_size, device) Y_multi = None # ============================================================ # PC mode: predictive coding inference + Hebbian-style updates # (when --pc_inference T > 0, replaces standard local CE) # ============================================================ if args.pc_inference > 0: optimizer.zero_grad() Y_flat = Y.view(-1) # 1. Forward init (no autograd graph during init) with torch.no_grad(): init_acts = model.forward_activations(X) # z[0] = embedding, clamped (no grad). z[1..L] = block outputs, evolve. z = [init_acts[0].detach()] for l in range(1, len(init_acts)): z.append(init_acts[l].detach().clone().requires_grad_(True)) # 2. PC inference: T iterations of z updates via ∂F/∂z # F = Σ_{l