From b83947778e2c776f757a07d4719b7ce961d7ed55 Mon Sep 17 00:00:00 2001 From: Yuren Hao Date: Fri, 3 Jul 2026 05:56:50 -0500 Subject: =?UTF-8?q?Initial=20commit:=20ept=20=E2=80=94=20backprop-free=20e?= =?UTF-8?q?quilibrium=20transformer=20(EP)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn --- ep_run/train_local_ce.py | 580 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 580 insertions(+) create mode 100644 ep_run/train_local_ce.py (limited to 'ep_run/train_local_ce.py') diff --git a/ep_run/train_local_ce.py b/ep_run/train_local_ce.py new file mode 100644 index 0000000..b8f790a --- /dev/null +++ b/ep_run/train_local_ce.py @@ -0,0 +1,580 @@ +"""Local CE exit training — each block gets a vocab-space CE loss via shared unembedding. + +Each block l computes: + z_l = W_U @ T_l(h_l) (local logits via shared unembedding + optional translator) + L_l = λ_gt * CE(z_l, y) + λ_kd * τ² * KL(sg(p_L^τ) || p_l^τ) + +Forward weights updated per-block via local CE gradient (intra-block only). +No inter-block chain rule. Fused attention backward within each block. + +This replaces the hidden-space MSE target-matching that failed at scale. 
def get_batch(split, data_dir, block_size, batch_size, device, n_pred=1):
    """Sample a random batch of token windows from a memory-mapped split file.

    Args:
        split: ``"train"`` reads ``train.bin``; any other value reads ``val.bin``.
        data_dir: directory containing the ``.bin`` files (``str`` or ``Path``).
        block_size: number of tokens in each input sequence.
        batch_size: number of sequences to draw.
        device: device the returned tensors are moved to.
        n_pred: number of future tokens returned as targets for each position.

    Returns:
        A tuple ``(x, y)`` of int64 tensors on ``device``. ``x`` has shape
        ``(batch_size, block_size)``. When ``n_pred == 1``, ``y`` has the same
        shape and holds the tokens one position after ``x``. When
        ``n_pred > 1``, ``y`` has shape ``(batch_size, block_size, n_pred)``
        and ``y[..., k - 1]`` holds the tokens ``k`` positions after ``x``.

    Raises:
        ValueError: if ``n_pred < 1`` or the split file holds too few tokens
            for a window of ``block_size + n_pred`` tokens.
    """
    if n_pred < 1:
        raise ValueError(f"n_pred must be >= 1, got {n_pred}")

    fn = "train.bin" if split == "train" else "val.bin"
    # Path(...) lets callers pass either a str or a Path for data_dir.
    data = np.memmap(Path(data_dir) / fn, dtype=np.uint16, mode="r")

    n_starts = len(data) - block_size - n_pred
    if n_starts <= 0:
        raise ValueError(
            f"{fn} holds {len(data)} tokens, which is too few for "
            f"block_size={block_size} with n_pred={n_pred}"
        )
    ix = torch.randint(n_starts, (batch_size,))

    # Read each window once (inputs plus the n_pred look-ahead tokens) and
    # slice inputs and targets out of it, instead of re-reading the memmap
    # separately for every target offset.
    span = block_size + n_pred
    windows = torch.stack([
        torch.from_numpy(data[start : start + span].astype(np.int64))
        for start in ix.tolist()
    ])

    # .contiguous() keeps the outputs safe for callers that use .view(-1).
    x = windows[:, :block_size].contiguous()
    if n_pred == 1:
        y = windows[:, 1 : block_size + 1].contiguous()
    else:
        # Last axis indexes the prediction horizon: y[..., k - 1] = next-k token.
        y = torch.stack(
            [windows[:, k : k + block_size] for k in range(1, n_pred + 1)],
            dim=-1,
        )
    return x.to(device, non_blocking=True), y.to(device, non_blocking=True)
class LowRankTranslator(nn.Module):
    """Residual affine map with a rank-limited linear part: T(h) = h + A (B h) + b.

    ``A`` has shape ``(d_model, rank)``, ``B`` has shape ``(rank, d_model)`` and
    ``bias`` has shape ``(d_model,)``. Both factors are drawn from N(0, 0.01^2)
    and the bias starts at zero.
    """

    def __init__(self, d_model, rank=32):
        super().__init__()
        # Parameters are registered in the order A, B, bias, and A is sampled
        # before B, so state-dict layout and RNG consumption stay fixed.
        up_factor = torch.empty(d_model, rank)
        down_factor = torch.empty(rank, d_model)
        self.A = nn.Parameter(up_factor)
        self.B = nn.Parameter(down_factor)
        self.bias = nn.Parameter(torch.zeros(d_model))
        for factor in (self.A, self.B):
            nn.init.normal_(factor, std=0.01)

    def forward(self, h):
        # Project down to `rank` dims first, then back up to d_model.
        compressed = torch.matmul(h, self.B.t())
        correction = torch.matmul(compressed, self.A.t())
        return h + correction + self.bias
class LocalCETransformer(nn.Module):
    """Transformer with per-block local CE exits via shared unembedding.

    Args:
        config: model configuration; this class reads ``vocab_size``,
            ``n_embd``, ``block_size``, ``dropout`` and ``n_layer`` from it and
            passes it on to ``LocalBlock`` and ``_make_ln``.
        translator_rank: if > 0, one ``LowRankTranslator`` of this rank is
            created per layer; otherwise ``self.translators`` is ``None`` and
            local logits are computed directly from the hidden state.
        n_pred_tokens: if > 1, ``n_pred_tokens - 1`` extra bias-free
            unembedding heads are created in ``self.aux_heads``; otherwise
            ``self.aux_heads`` is ``None``.
        shared_blocks: if true, a single ``LocalBlock`` instance is reused for
            every entry of ``self.blocks``.
    """

    def __init__(self, config: LocalGPTConfig, translator_rank: int = 0, n_pred_tokens: int = 1,
                 shared_blocks: bool = False):
        super().__init__()
        self.config = config
        self.n_pred_tokens = n_pred_tokens
        self.shared_blocks = shared_blocks
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.block_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        if shared_blocks:
            # Universal Transformer: one block applied n_layer times.
            # All entries point to the SAME module — gradient accumulates from all "depths".
            shared = LocalBlock(config)
            self.blocks = nn.ModuleList([shared for _ in range(config.n_layer)])
        else:
            self.blocks = nn.ModuleList([LocalBlock(config) for _ in range(config.n_layer)])
        self.ln_f = _make_ln(config)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Auxiliary unembedding heads for next-2..next-N prediction (multi-token training).
        # Used only as gradient-source heads at training time; inference still uses self.head.
        if n_pred_tokens > 1:
            self.aux_heads = nn.ModuleList([
                nn.Linear(config.n_embd, config.vocab_size, bias=False)
                for _ in range(n_pred_tokens - 1)
            ])
        else:
            self.aux_heads = None

        # Per-block translators (logit lens = rank 0 = identity)
        if translator_rank > 0:
            self.translators = nn.ModuleList([
                LowRankTranslator(config.n_embd, translator_rank)
                for _ in range(config.n_layer)
            ])
        else:
            self.translators = None

        self.apply(self._init_weights)
        # Output projections get a smaller init, scaled by 1/sqrt(2 * n_layer).
        for pn, p in self.named_parameters():
            if pn.endswith("o_proj.weight") or pn.endswith("mlp.proj.weight"):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, m):
        """Initialise linear and embedding weights from N(0, 0.02^2); zero any linear bias."""
        if isinstance(m, (nn.Linear, LocalLinear)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward_activations(self, idx):
        """Run the embedding and every block, returning all intermediate states.

        Args:
            idx: 2-D tensor of token ids, unpacked as ``(batch, time)``.

        Returns:
            A list of ``len(self.blocks) + 1`` tensors: entry 0 is the dropped-out
            sum of token and position embeddings, entry ``l`` is the output of
            block ``l``.

        Raises:
            ValueError: if the sequence is longer than ``config.block_size``,
                the size of the position-embedding table.
        """
        _, T = idx.shape
        # Fail early with a clear message rather than inside the position
        # embedding lookup, which has only config.block_size rows.
        if T > self.config.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.config.block_size}"
            )
        pos = torch.arange(T, device=idx.device)
        h = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        activations = [h]
        for block in self.blocks:
            h = block(h)
            activations.append(h)
        return activations

    def local_logits(self, h, layer_idx):
        """h → local logits via optional translator + shared unembedding.

        ``self.ln_f`` is not applied on this path; only ``final_logits`` uses it.
        """
        if self.translators is not None:
            h = self.translators[layer_idx](h)
        return F.linear(h, self.head.weight)  # shared W_U, no separate head

    def final_logits(self, h):
        """Return ``self.head(self.ln_f(h))``: final norm followed by the unembedding."""
        return self.head(self.ln_f(h))
type=int, default=6) + p.add_argument("--n_embd", type=int, default=384) + p.add_argument("--dropout", type=float, default=0.2) + p.add_argument("--max_iters", type=int, default=5000) + p.add_argument("--warmup_iters", type=int, default=100) + p.add_argument("--max_lr", type=float, default=1e-3) + p.add_argument("--min_lr", type=float, default=1e-4) + p.add_argument("--eval_interval", type=int, default=250) + p.add_argument("--eval_iters", type=int, default=100) + p.add_argument("--log_interval", type=int, default=50) + p.add_argument("--attn_mode", type=str, default="softmax") + p.add_argument("--translator_rank", type=int, default=0, help="0=identity (logit lens), >0=low-rank affine") + p.add_argument("--kd_weight", type=float, default=1.0, help="weight for KL distillation from final layer") + p.add_argument("--kd_temp", type=float, default=2.0, help="temperature for KD") + p.add_argument("--gt_weight", type=float, default=1.0, help="weight for ground-truth CE") + p.add_argument("--nbr_weight", type=float, default=0.0, help="weight for neighbor KL (sg(p_{l+1}) || p_l)") + p.add_argument("--layer_weighting", type=str, default="uniform", choices=["uniform", "linear"], + help="per-layer loss weight: uniform=all 1.0, linear=l/L") + p.add_argument("--bp_free_exit", type=str, default="none", + choices=["none", "dense", "hybrid", "parallel_only", "parallel_gold", "parallel_topmass"], + help="BP-free exit: none=W_U^T, dense/hybrid=compressor, parallel_*=exact parallel term") + p.add_argument("--exit_rank", type=int, default=128, help="rank for BP-free exit compressor") + p.add_argument("--exit_rank_exact", type=int, default=32, help="exact rank for hybrid compressor") + p.add_argument("--exit_topk", type=int, default=8, help="top-k for hybrid compressor") + p.add_argument("--exit_residual_rank", type=int, default=32, + help="residual_rank for ExactParallelExitHead (parallel_gold/topmass): code dim for h-perp residual") + p.add_argument("--intra_block_method", type=str, 
default="bp", choices=["bp", "fa", "sign_sym", "dfa_block"], + help="intra-block: bp=W^T, fa=seq random B, sign_sym=sign(W)·rescale, dfa_block=direct from block-output-error") + p.add_argument("--mlp_topk", type=int, default=0, + help="if >0, apply hard k-WTA to MLP hidden activation (4*n_embd dim)") + p.add_argument("--resid_topk", type=int, default=0, + help="if >0, apply hard k-WTA to residual stream output of each block (n_embd dim)") + p.add_argument("--vq_codes", type=int, default=0, + help="if >0, apply directional VQ to residual stream at each block (K codebook entries, frozen)") + p.add_argument("--subspace_rank", type=int, default=0, + help="if >0, project residual stream to fixed r-dim orthonormal subspace at each block") + p.add_argument("--subspace_per_layer", action="store_true", + help="use DIFFERENT random Q per layer (ablation: tests if shared Q is necessary)") + p.add_argument("--fa_init_sign", action="store_true", + help="init FA's fixed B as sign(W_init)*rescale instead of random (frozen sign_sym)") + p.add_argument("--shared_blocks", action="store_true", + help="Universal Transformer: all blocks share the same parameters (single block applied n_layer times)") + p.add_argument("--fa_init", type=str, default="gaussian", + choices=["gaussian", "orthogonal", "ortho_he", "sparse"], + help="FA's fixed B init mode (gaussian=Lillicrap, orthogonal=JL-isometric, ortho_he=He-init backward, sparse=structured)") + p.add_argument("--fa_sparse_k", type=int, default=0, + help="for fa_init=sparse: non-zero entries per row (0 = auto = in_features/16)") + p.add_argument("--gated_blocks", action="store_true", + help="Path IV: learned per-block residual gates (α_attn, α_mlp). 
Lets useless layers self-deactivate.") + p.add_argument("--progression_targets", action="store_true", + help="Path I: each block l predicts next-(l+1) token (progressive prediction horizons per layer)") + p.add_argument("--weight_normalize", action="store_true", + help="Meta-PCN style WN: after each optimizer step, normalize LocalLinear's W by (sqrt(m)+sqrt(n))*std(W) to keep ||W||_2 ~= 1") + p.add_argument("--pc_inference", type=int, default=0, + help="Predictive coding inference steps T (T=0 disables PC mode, uses standard local CE)") + p.add_argument("--pc_inference_lr", type=float, default=0.1, + help="Inference step size η for PC z updates") + p.add_argument("--pc_top_weight", type=float, default=1.0, + help="Weight of top-down CE term in PC energy F") + p.add_argument("--fa_grape", action="store_true", + help="GrAPE: per-step JVP-based cosine alignment of FA's B toward true Jacobian (Caillon et al. 2026)") + p.add_argument("--fa_grape_lr", type=float, default=0.01, + help="Learning rate for GrAPE B alignment update") + p.add_argument("--fa_grape_n_probe", type=int, default=32, + help="Number of probe samples for JVP rank-1 Jacobian estimate") + p.add_argument("--save_ckpt", action="store_true", + help="save final model state to run_dir/ckpt.pt for downstream probing") + p.add_argument("--n_pred_tokens", type=int, default=1, + help="multi-token prediction: predict next-1..next-N (N=1 disables, default)") + p.add_argument("--aux_weight", type=float, default=0.3, + help="weight for aux next-k losses (k=2..N). 
Primary next-1 always weight 1.0.") + args = p.parse_args() + + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + device = "cuda" if torch.cuda.is_available() else "cpu" + + data_dir = Path(args.data_dir) + with open(data_dir / "meta.pkl", "rb") as f: + meta = pickle.load(f) + vocab_size = meta["vocab_size"] + + run_dir = Path(args.out_dir) / args.run_name + run_dir.mkdir(parents=True, exist_ok=True) + log_path = run_dir / "log.jsonl" + log_path.write_text("") + + cfg = LocalGPTConfig( + block_size=args.block_size, vocab_size=vocab_size, + n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd, + dropout=args.dropout, attn_mode=args.attn_mode, + method=args.intra_block_method, + fuse_attn_local=True, ste_gelu=True, ln_mode="center_scale", + mlp_topk=args.mlp_topk, resid_topk=args.resid_topk, + vq_codes=args.vq_codes, subspace_rank=args.subspace_rank, + fa_init_mode=args.fa_init, fa_sparse_k=args.fa_sparse_k, + gated_blocks=args.gated_blocks, + fa_grape=args.fa_grape, fa_grape_n_probe=args.fa_grape_n_probe, + ) + model = LocalCETransformer(cfg, translator_rank=args.translator_rank, + n_pred_tokens=args.n_pred_tokens, + shared_blocks=args.shared_blocks).to(device) + + # Frozen sign_sym: replace FA's random B with sign(W_init)*rescale, then freeze. + # B is still a fixed buffer (BP-free by definition B), just structured init. 
+ if args.fa_init_sign and args.intra_block_method == "fa": + with torch.no_grad(): + for module in model.modules(): + if isinstance(module, LocalLinear) and module.method == "fa": + scale = module.weight.norm() / (module.weight.numel() ** 0.5 + 1e-8) + module.B.copy_(torch.sign(module.weight) * scale) + + # Per-layer different Q ablation: replace each block's shared-seed subspace + # with independently-seeded subspace (tests if shared Q is the mechanism) + if args.subspace_per_layer and args.subspace_rank > 0: + from model_local import FrozenSubspace + for i, block in enumerate(model.blocks): + block.subspace = FrozenSubspace(args.n_embd, args.subspace_rank, seed=1000 + i).to(device) + + # Initialize DFA-block targets if needed + if args.intra_block_method == "dfa_block": + initialize_dfa_block_targets(model, args.n_embd) + + # BP-free exit heads (one per block) + exit_heads = None + if args.bp_free_exit in ("dense", "hybrid"): + exit_heads = nn.ModuleList([ + FactorizedExitHead( + args.n_embd, vocab_size, mode=args.bp_free_exit, + rank=args.exit_rank, rank_exact=args.exit_rank_exact, topk=args.exit_topk, + ) for _ in range(cfg.n_layer) + ]).to(device) + elif args.bp_free_exit.startswith("parallel"): + exit_heads = nn.ModuleList([ + ExactParallelExitHead( + args.n_embd, vocab_size, mode=args.bp_free_exit, + residual_rank=args.exit_residual_rank, + ) for _ in range(cfg.n_layer) + ]).to(device) + + n_params = sum(p.numel() for p in model.parameters()) + optimizer = torch.optim.AdamW(model.parameters(), lr=args.max_lr, weight_decay=0.1) + + t0 = time.time() + + def log(rec): + rec["t"] = time.time() - t0 + with open(log_path, "a") as f: + f.write(json.dumps(rec) + "\n") + + log({"event": "start", "method": "local_ce", "params": n_params, + "translator_rank": args.translator_rank, "config": vars(args)}) + print(f"[{args.run_name}] local_ce, params={n_params/1e6:.2f}M, translator_rank={args.translator_rank}") + + def lr_schedule(it): + if it < args.warmup_iters: + 
return args.max_lr * (it + 1) / (args.warmup_iters + 1) + decay = 0.5 * (1 + math.cos(math.pi * (it - args.warmup_iters) / + max(1, args.max_iters - args.warmup_iters))) + return args.min_lr + decay * (args.max_lr - args.min_lr) + + @torch.no_grad() + def eval_loss(): + model.eval() + losses = torch.zeros(args.eval_iters) + for k in range(args.eval_iters): + X, Y = get_batch("val", data_dir, args.block_size, args.batch_size, device) + acts = model.forward_activations(X) + logits = model.final_logits(acts[-1]) + loss = F.cross_entropy(logits.view(-1, vocab_size), Y.view(-1)) + losses[k] = loss.item() + model.train() + return losses.mean().item() + + model.train() + for it in range(args.max_iters + 1): + lr = lr_schedule(it) + for g in optimizer.param_groups: + g["lr"] = lr + + if it % args.eval_interval == 0 or it == args.max_iters: + val = eval_loss() + log({"event": "eval", "iter": it, "val_loss": val, "lr": lr}) + print(f"[{args.run_name}] iter {it:5d} val {val:.4f} lr {lr:.4g}") + + if it == args.max_iters: + break + + # Determine n_pred for batch fetch + # - n_pred_tokens > 1: multi-token MTP aux losses (each block predicts N targets via N heads) + # - progression_targets: each block l predicts next-(l+1) (so need n_pred = n_layer) + if args.n_pred_tokens > 1: + n_pred = args.n_pred_tokens + elif args.progression_targets: + n_pred = cfg.n_layer + else: + n_pred = 1 + + if n_pred > 1: + X, Y_multi = get_batch("train", data_dir, args.block_size, args.batch_size, device, + n_pred=n_pred) + Y = Y_multi[..., 0] # (B, T) — next-1 target for default + else: + X, Y = get_batch("train", data_dir, args.block_size, args.batch_size, device) + Y_multi = None + + # ============================================================ + # PC mode: predictive coding inference + Hebbian-style updates + # (when --pc_inference T > 0, replaces standard local CE) + # ============================================================ + if args.pc_inference > 0: + optimizer.zero_grad() + Y_flat = 
Y.view(-1) + + # 1. Forward init (no autograd graph during init) + with torch.no_grad(): + init_acts = model.forward_activations(X) + # z[0] = embedding, clamped (no grad). z[1..L] = block outputs, evolve. + z = [init_acts[0].detach()] + for l in range(1, len(init_acts)): + z.append(init_acts[l].detach().clone().requires_grad_(True)) + + # 2. PC inference: T iterations of z updates via ∂F/∂z + # F = Σ_{l