diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/factorized_exit.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/factorized_exit.py')
| -rw-r--r-- | ep_run/factorized_exit.py | 330 |
1 files changed, 330 insertions, 0 deletions
# ep_run/factorized_exit.py
"""Factorized BP-free exit feedback for local CE training.

Replaces W_U^T(p-y) with α · C(p-y) @ U^T where:
  C: fixed compressor (dense random or hybrid gold+topk+tail-sketch)
  U: fixed orthonormal expander (d, r)
  α: scalar gain

Forward logits = h @ W_U^T  (exact, unchanged)
grad_W = exact local CE gradient (no weight transport)
grad_h = factorized BP-free signal (no W_U^T)

Two compressor modes:
  dense:  g @ C  where C is (V, r) fixed random
  hybrid: [gold + top-k exact codes, CountSketch(tail)]
"""
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def orthonormal_columns(d_out, rank, *, device=None, dtype=torch.float32, seed=None):
    """Return a ``(d_out, rank)`` matrix with orthonormal columns.

    The matrix is the reduced-QR ``Q`` factor of a Gaussian matrix drawn on
    the CPU and is moved to ``device`` afterwards.

    Args:
        d_out: number of rows.
        rank: number of columns; must satisfy ``1 <= rank <= d_out``.
        device: device of the returned tensor.
        dtype: dtype used for sampling and for the QR factorization.
        seed: optional seed for the private CPU generator.

    Raises:
        ValueError: if ``rank`` is outside ``[1, d_out]``.
    """
    if rank <= 0 or rank > d_out:
        raise ValueError(f"rank must satisfy 1 <= rank <= d_out, got rank={rank}, d_out={d_out}")
    gen = torch.Generator(device="cpu")
    if seed is not None:
        gen.manual_seed(seed)
    q, _ = torch.linalg.qr(
        torch.randn(d_out, rank, dtype=dtype, generator=gen), mode="reduced"
    )
    return q.contiguous().to(device=device)


class DenseRandomCompressor(nn.Module):
    """Compress logit gradients with a fixed dense Gaussian codebook.

    The ``(vocab_size, rank)`` codebook is scaled by ``1 / sqrt(rank)`` and
    stored as a buffer, so it is never trained.
    """

    def __init__(self, vocab_size, rank, *, seed: Optional[int] = None):
        super().__init__()
        self.vocab_size = vocab_size
        self._rank = rank
        gen = torch.Generator()
        if seed is not None:
            gen.manual_seed(seed)
        codebook = torch.randn(vocab_size, rank, generator=gen) / math.sqrt(rank)
        self.register_buffer("codebook", codebook)

    @property
    def rank(self):
        """Width of the compressed code."""
        return self._rank

    @torch.no_grad()
    def compress(self, grad_logits, targets):
        """Return ``grad_logits @ codebook`` in float32.

        ``targets`` is accepted for interface parity with the hybrid
        compressor and is not used.
        """
        return grad_logits.float() @ self.codebook.float()


class HybridTopKTailSketchCompressor(nn.Module):
    """Exact codes for gold + top-k tokens, CountSketch for the remaining tail.

    The output code is the concatenation ``[exact head | tail sketch]`` of
    width ``rank_exact + rank_tail``:

    * exact head: sum of ``g_j * code_j`` over the gold token and the ``topk``
      largest entries of the gradient row, using fixed unit-norm random codes;
    * tail sketch: signed bucket sums of the gradient row with the gold and
      top-k entries removed.

    Raises:
        ValueError: if ``rank_exact`` or ``rank_tail`` is negative, or both
            are zero.
    """

    def __init__(self, vocab_size, *, rank_exact=32, rank_tail=96, topk=8,
                 seed: Optional[int] = None):
        super().__init__()
        if rank_exact < 0 or rank_tail < 0:
            raise ValueError(
                f"rank_exact and rank_tail must be >= 0, got rank_exact={rank_exact}, "
                f"rank_tail={rank_tail}"
            )
        if rank_exact + rank_tail == 0:
            raise ValueError("rank_exact + rank_tail must be >= 1")
        self.vocab_size = vocab_size
        self.rank_exact = rank_exact
        self.rank_tail = rank_tail
        self.topk = topk

        gen = torch.Generator()
        if seed is not None:
            gen.manual_seed(seed)

        # NOTE: the draw order (codes, bucket, sign) fixes the buffers obtained
        # for a given seed; do not reorder.
        if rank_exact > 0:
            codes = torch.randn(vocab_size, rank_exact, generator=gen)
            codes = codes / codes.norm(dim=-1, keepdim=True).clamp_min(1e-6)
        else:
            codes = torch.empty(vocab_size, 0)
        self.register_buffer("exact_codes", codes)

        if rank_tail > 0:
            bucket = torch.randint(0, rank_tail, (vocab_size,), generator=gen)
            sign = torch.randint(0, 2, (vocab_size,), generator=gen).float() * 2 - 1
        else:
            # Unused when there is no tail; zeros keep the state dict deterministic.
            bucket = torch.zeros(vocab_size, dtype=torch.long)
            sign = torch.zeros(vocab_size)
        self.register_buffer("bucket", bucket)
        self.register_buffer("sign", sign)

    @property
    def rank(self):
        """Total width of the compressed code."""
        return self.rank_exact + self.rank_tail

    @torch.no_grad()
    def compress(self, grad_logits, targets):
        """Compress ``grad_logits`` (..., V) into a code of shape (..., rank).

        Args:
            grad_logits: gradient w.r.t. the logits; the last dim is the vocab.
            targets: gold token ids, one per gradient row. Negative ids are
                clamped to 0 before indexing.

        Returns:
            Tensor of shape ``grad_logits.shape[:-1] + (rank,)`` in the dtype
            of ``grad_logits``.
        """
        g = grad_logits.float()
        V = g.size(-1)
        lead_shape = g.shape[:-1]
        g_flat = g.reshape(-1, V)
        n_rows = g_flat.size(0)

        safe_t = targets.reshape(-1).long().clamp(min=0)
        gold_idx = safe_t.unsqueeze(1)                     # (N, 1)
        gold_grad = g_flat.gather(1, gold_idx)             # (N, 1)

        # Top-k is needed by both the exact head and the tail, so compute it
        # once up front (it used to be defined only when rank_exact > 0).
        k_eff = max(0, min(self.topk, V - 1))
        topv = topi = None
        if k_eff > 0:
            topv, topi = g_flat.topk(k_eff, dim=1)         # (N, k)
            # The gold token is handled separately; if it also shows up among
            # the top-k, zero that slot so it is not counted/removed twice.
            topv = topv.masked_fill(topi == gold_idx, 0.0)

        parts = []

        # Exact head: gold + top-k
        if self.rank_exact > 0:
            codes = self.exact_codes.float()
            c_exact = gold_grad * codes[safe_t]            # (N, r_exact)
            if topi is not None:
                c_exact = c_exact + (topv.unsqueeze(-1) * codes[topi]).sum(dim=1)
            parts.append(c_exact)

        # Tail CountSketch
        if self.rank_tail > 0:
            sign = self.sign.float()
            c_tail = g_flat.new_zeros(n_rows, self.rank_tail)
            c_tail.scatter_add_(
                1, self.bucket.unsqueeze(0).expand(n_rows, V), g_flat * sign.unsqueeze(0)
            )
            # Remove gold contribution from tail
            c_tail.scatter_add_(1, self.bucket[gold_idx], -(gold_grad * sign[gold_idx]))
            # Remove top-k from tail. scatter_add_ accumulates when several
            # top-k tokens of a row share a bucket; plain indexed `-=` would
            # apply only one of the colliding updates.
            if topi is not None:
                c_tail.scatter_add_(1, self.bucket[topi], -(topv * sign[topi]))
            parts.append(c_tail)

        c = torch.cat(parts, dim=1)
        return c.reshape(*lead_shape, self.rank).to(dtype=grad_logits.dtype)


class _FactorizedExitFn(torch.autograd.Function):
    """Exact forward logits; backward sends a compressed signal to ``h``.

    backward returns ``alpha * compress(grad_logits) @ U^T`` for ``h`` and the
    exact ``grad_logits^T h`` for ``weight``.
    """

    @staticmethod
    def forward(ctx, h, weight, targets, U, alpha, compressor):
        ctx.compressor = compressor
        ctx.save_for_backward(h.detach(), weight.detach(), targets, U.detach(), alpha.detach())
        # The logits are not needed in backward, so they are not retained.
        return h @ weight.t()

    @staticmethod
    def backward(ctx, grad_logits):
        h, weight, targets, U, alpha = ctx.saved_tensors

        # Exact W gradient (no transport)
        g_flat = grad_logits.reshape(-1, grad_logits.size(-1)).float()
        h_flat = h.reshape(-1, h.size(-1)).float()
        grad_weight = g_flat.t() @ h_flat

        # BP-free hidden signal via compressor
        c = ctx.compressor.compress(grad_logits, targets).float()
        grad_h = (alpha * c) @ U.float().t()

        return grad_h.to(h.dtype), grad_weight.to(weight.dtype), None, None, None, None


class _ExactParallelExitFn(torch.autograd.Function):
    """Exit backward using only the exact recoverable parallel component.

        g_h_parallel = ((p-y)^T z / (||h||^2 + eps)) * h

    This is the ONLY component of W_U^T(p-y) that is identifiable from
    forward quantities alone. The h-perp component is informationally
    invisible without W_U.

    NOTE(review): the hidden gradient is built from ``p - y`` recomputed from
    the saved logits, not from the incoming ``grad_logits``, so it does not
    carry any upstream scaling (e.g. a mean-reduction factor). Only the weight
    gradient uses ``grad_logits``. Confirm this is intended.
    """

    @staticmethod
    def forward(ctx, h, weight, targets, residual_fn):
        logits = h @ weight.t()
        ctx.save_for_backward(h.detach(), weight.detach(), targets)
        ctx.logits_detached = logits.detach()
        ctx.residual_fn = residual_fn  # optional h-perp residual
        return logits

    @staticmethod
    def backward(ctx, grad_logits):
        h, weight, targets = ctx.saved_tensors
        logits = ctx.logits_detached
        residual_fn = ctx.residual_fn

        # Exact W gradient (no transport)
        g_flat = grad_logits.reshape(-1, grad_logits.size(-1)).float()
        h_flat = h.reshape(-1, h.size(-1)).float()
        grad_weight = g_flat.t() @ h_flat

        # e = p - y_onehot, built in place on the softmax output to avoid
        # materializing a one-hot tensor of the same size as the logits.
        e = F.softmax(logits, dim=-1)                      # (..., V)
        target_idx = targets.long().clamp(min=0).unsqueeze(-1)
        e.scatter_add_(-1, target_idx, e.new_full(target_idx.shape, -1.0))
        # Negative targets were clamped to 0 only to keep the index valid;
        # such positions carry no label, so they must not emit a signal.
        e.masked_fill_((targets < 0).unsqueeze(-1), 0.0)

        # (p-y)^T z = p^T z - z_y
        h32 = h.float()
        e_dot_z = (e * logits).sum(dim=-1, keepdim=True)           # (..., 1)
        h_norm_sq = (h32 * h32).sum(dim=-1, keepdim=True) + 1e-8   # (..., 1)
        grad_h = (e_dot_z / h_norm_sq) * h32                       # (..., d)

        # Optional orthogonal residual; a callback may decline by returning None.
        if residual_fn is not None:
            residual = residual_fn(h32, e, logits, targets)
            if residual is not None:
                grad_h = grad_h + residual

        return grad_h.to(h.dtype), grad_weight.to(weight.dtype), None, None


class FactorizedExitHead(nn.Module):
    """Drop-in BP-free local CE exit head.

    Args:
        d_model: hidden width ``d``.
        vocab_size: vocabulary size ``V``.
        mode: ``"dense"`` or ``"hybrid"`` compressor.
        rank: total code width ``r`` (columns of the expander ``U``).
        rank_exact: width of the exact head in hybrid mode; the tail sketch
            gets ``rank - rank_exact``.
        topk: number of top entries coded exactly in hybrid mode.
        alpha_init: value of the scalar gain buffer ``alpha``.
        seed: seed shared by the compressor and the expander.

    Raises:
        ValueError: on an unknown ``mode``, or in hybrid mode when
            ``rank_exact`` is not in ``[0, rank]``.
    """

    def __init__(self, d_model, vocab_size, *, mode="hybrid", rank=128,
                 rank_exact=32, topk=8, alpha_init=1.0, seed=None):
        super().__init__()
        if mode == "dense":
            self.compressor = DenseRandomCompressor(vocab_size, rank, seed=seed)
        elif mode == "hybrid":
            # Fail at construction rather than at the first backward pass.
            if not 0 <= rank_exact <= rank:
                raise ValueError(
                    f"hybrid mode requires 0 <= rank_exact <= rank, got "
                    f"rank_exact={rank_exact}, rank={rank}"
                )
            self.compressor = HybridTopKTailSketchCompressor(
                vocab_size, rank_exact=rank_exact, rank_tail=rank - rank_exact,
                topk=topk, seed=seed,
            )
        else:
            raise ValueError(f"Unknown mode: {mode}")

        U = orthonormal_columns(d_model, self.compressor.rank, seed=seed)
        self.register_buffer("U", U)
        self.register_buffer("alpha", torch.tensor(alpha_init))
        self.vocab_size = vocab_size

    def forward(self, h, shared_weight, targets):
        """h: (B,T,d), shared_weight: (V,d), targets: (B,T) → logits: (B,T,V)"""
        return _FactorizedExitFn.apply(
            h, shared_weight, targets, self.U, self.alpha, self.compressor
        )


class ExactParallelExitHead(nn.Module):
    """BP-free exit using exact parallel component + optional h-perp residual.

    Modes:
      parallel_only:    g̃_h = (e^T z / ||h||²) h           (exact parallel only)
      parallel_gold:    + λ R(h) (e_y q_y)                  (+ gold token code in h⊥)
      parallel_topmass: + λ R(h) (e_y q_y + Σ_{j∈S} e_j q_j)  (+ top-mass codes in h⊥)

    Raises:
        ValueError: on an unknown ``mode``.
    """

    _MODES = ("parallel_only", "parallel_gold", "parallel_topmass")
    # Number of highest-probability tokens pre-selected for the top-mass set.
    TOPMASS_PRESELECT = 200
    # Rows processed per step when gathering the top-mass codes.
    CODE_CHUNK = 1024

    def __init__(self, d_model, vocab_size, *, mode="parallel_only",
                 residual_rank=32, residual_lambda=0.1, mass_threshold=0.95, seed=None):
        super().__init__()
        if mode not in self._MODES:
            raise ValueError(f"Unknown mode: {mode}")
        self.vocab_size = vocab_size
        self.mode = mode
        self.residual_lambda = residual_lambda
        self.mass_threshold = mass_threshold

        if mode in ("parallel_gold", "parallel_topmass"):
            # Fixed random token codes for h-perp residual
            gen = torch.Generator()
            if seed is not None:
                gen.manual_seed(seed)
            codes = torch.randn(vocab_size, residual_rank, generator=gen)
            codes = codes / codes.norm(dim=-1, keepdim=True).clamp_min(1e-6)
            self.register_buffer("token_codes", codes)

            # Fixed base Q for constructing R(h) ∈ h⊥
            Q = orthonormal_columns(d_model, residual_rank, seed=seed)
            self.register_buffer("Q_base", Q)
        else:
            self.token_codes = None
            self.Q_base = None

    def _residual_fn(self, h, e, logits, targets):
        """Compute h-perp residual: λ R(h) C_head(e).

        Args:
            h: hidden states (..., d).
            e: ``p - y`` per token (..., V).
            logits: logits (..., V); used only in ``parallel_topmass`` mode.
            targets: gold token ids; negative ids are clamped to 0.

        Returns:
            Residual of shape (..., d), or ``None`` when the head has no
            token codes (``parallel_only``).
        """
        if self.mode == "parallel_only" or self.token_codes is None:
            return None

        lead_shape = h.shape[:-1]

        # R(h): project Q_base into h⊥ without materializing an (N, d, r) tensor.
        h_hat = h / (h.norm(dim=-1, keepdim=True) + 1e-8)          # (..., d)
        Q = self.Q_base.float()                                    # (d, r)
        hQ = (h_hat.unsqueeze(-2) @ Q).squeeze(-2)                 # (..., r)
        # Squared norm of each projected column is 1 - hQ_j^2, because the
        # columns of Q are orthonormal and h_hat is (near) unit-norm.
        col_norm_inv = (1.0 - hQ ** 2).clamp_min(1e-8).rsqrt()     # (..., r)

        # C_head(e): gold + (optionally) top-mass codes
        t_flat = targets.reshape(-1).long().clamp(min=0)
        e_flat = e.reshape(-1, self.vocab_size)
        n_rows = e_flat.size(0)

        codes = self.token_codes.float()
        gold_grad = e_flat.gather(1, t_flat.unsqueeze(1))          # (N, 1)
        c = gold_grad * codes[t_flat]                              # (N, r)

        k_pre = min(self.TOPMASS_PRESELECT, self.vocab_size - 1)
        if self.mode == "parallel_topmass" and k_pre > 0:
            # Adaptive top-mass set: pre-select k_pre tokens, then keep the
            # prefix whose cumulative probability stays within the threshold
            # (always at least the first one). Avoids a full sort.
            p_flat = F.softmax(logits.reshape(-1, self.vocab_size).float(), dim=-1)
            top_p, top_idx = p_flat.topk(k_pre, dim=1)             # (N, k_pre)
            keep_mask = top_p.cumsum(dim=-1) <= self.mass_threshold
            keep_mask[:, 0] = True
            top_e_masked = e_flat.gather(1, top_idx) * keep_mask.float()
            # Chunk over rows so codes[top_idx] never exceeds
            # (CODE_CHUNK, k_pre, r). A constant step also keeps range()
            # valid when there are zero rows.
            for start in range(0, n_rows, self.CODE_CHUNK):
                stop = min(start + self.CODE_CHUNK, n_rows)
                chunk_codes = codes[top_idx[start:stop]]           # (chunk, k_pre, r)
                c[start:stop] += (top_e_masked[start:stop].unsqueeze(-1) * chunk_codes).sum(dim=1)

        c = c.reshape(*lead_shape, -1)                             # (..., r)

        # R(h) @ c, in O(N*(d+r)) memory:
        #   residual = Σ_j (c_j / ||Q_bar_j||) * (Q_j - h_hat * hQ_j)
        #            = Q @ c_adj - h_hat * (c_adj · hQ)
        c_adj = c * col_norm_inv                                   # (..., r)
        residual = c_adj @ Q.t() - h_hat * (c_adj * hQ).sum(dim=-1, keepdim=True)
        return self.residual_lambda * residual

    def forward(self, h, shared_weight, targets):
        """Return ``h @ shared_weight^T``; the custom backward is attached."""
        residual_fn = self._residual_fn if self.mode != "parallel_only" else None
        return _ExactParallelExitFn.apply(h, shared_weight, targets, residual_fn)
