"""Factorized BP-free exit feedback for local CE training.

Replaces W_U^T (p - y) with α · C(p - y) @ U^T where:

    C: fixed compressor (dense random, or hybrid gold + top-k + tail sketch)
    U: fixed orthonormal expander of shape (d, r)
    α: scalar gain

    Forward logits = h @ W_U^T   (exact, unchanged)
    grad_W         = exact local CE gradient (no weight transport)
    grad_h         = factorized BP-free signal (no W_U^T)

Two compressor modes:

    dense:  g @ C, where C is a fixed random (V, r) matrix
    hybrid: [gold + top-k exact codes, CountSketch(tail)]

The module also provides ``ExactParallelExitHead``, whose hidden-state signal
keeps only the component of W_U^T (p - y) along h, optionally plus a
fixed-code residual projected into the orthogonal complement of h.
"""
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def orthonormal_columns(d_out, rank, *, device=None, dtype=torch.float32, seed: Optional[int] = None):
    """Return a ``(d_out, rank)`` matrix with orthonormal columns.

    The matrix is the reduced QR factor of a Gaussian sample drawn from a CPU
    generator (seeded with ``seed`` when given), then moved to ``device``.

    Raises:
        ValueError: if ``rank`` is not in ``[1, d_out]``.
    """
    if rank <= 0 or rank > d_out:
        raise ValueError(f"rank must satisfy 1 <= rank <= d_out, got rank={rank}, d_out={d_out}")
    gen = torch.Generator(device="cpu")
    if seed is not None:
        gen.manual_seed(seed)
    # QR is only implemented for float32/float64; factor low-precision requests
    # in float32 and cast afterwards (a no-op for the float32 default).
    work_dtype = dtype if dtype in (torch.float32, torch.float64) else torch.float32
    q, _ = torch.linalg.qr(
        torch.randn(d_out, rank, dtype=work_dtype, generator=gen), mode="reduced"
    )
    return q.contiguous().to(device=device, dtype=dtype)


class DenseRandomCompressor(nn.Module):
    """Compress a (..., V) gradient to (..., rank) via a fixed Gaussian codebook.

    The codebook has shape (vocab_size, rank) with entries ~ N(0, 1/rank).
    """

    def __init__(self, vocab_size, rank, *, seed=None):
        super().__init__()
        self.vocab_size = vocab_size
        self._rank = rank
        gen = torch.Generator()
        if seed is not None:
            gen.manual_seed(seed)
        codebook = torch.randn(vocab_size, rank, generator=gen) / math.sqrt(rank)
        self.register_buffer("codebook", codebook)

    @property
    def rank(self):
        return self._rank

    @torch.no_grad()
    def compress(self, grad_logits, targets):
        """Return ``grad_logits @ codebook`` in float32; ``targets`` is unused."""
        return grad_logits.float() @ self.codebook.float()


class HybridTopKTailSketchCompressor(nn.Module):
    """Hybrid compressor: exact codes for gold + top-k, CountSketch for the rest.

    Output columns are ``[exact head (rank_exact), tail sketch (rank_tail)]``.

    * Head: ``g_gold * code[gold] + sum_{j in top-k} g_j * code[j]``, where
      top-k is taken over the largest gradient values and codes are fixed
      unit-norm random vectors.
    * Tail: CountSketch (fixed random bucket + sign per token) of every entry
      the head does not already carry. With ``rank_exact == 0`` there is no
      head, so the tail sketches the full gradient row.

    Negative targets (e.g. an ignore_index) are treated as "no gold token".
    """

    def __init__(self, vocab_size, *, rank_exact=32, rank_tail=96, topk=8, seed=None):
        super().__init__()
        if rank_exact < 0 or rank_tail < 0:
            raise ValueError(
                "rank_exact and rank_tail must be >= 0, "
                f"got rank_exact={rank_exact}, rank_tail={rank_tail}"
            )
        if rank_exact + rank_tail == 0:
            raise ValueError("rank_exact + rank_tail must be positive")
        self.vocab_size = vocab_size
        self.rank_exact = rank_exact
        self.rank_tail = rank_tail
        self.topk = topk
        gen = torch.Generator()
        if seed is not None:
            gen.manual_seed(seed)
        # Draw order (codes, then bucket, then sign) is kept stable so a given
        # seed always yields the same buffers.
        if rank_exact > 0:
            codes = torch.randn(vocab_size, rank_exact, generator=gen)
            codes = codes / codes.norm(dim=-1, keepdim=True).clamp_min(1e-6)
        else:
            codes = torch.empty(vocab_size, 0)
        self.register_buffer("exact_codes", codes)
        if rank_tail > 0:
            bucket = torch.randint(0, rank_tail, (vocab_size,), generator=gen)
            sign = torch.randint(0, 2, (vocab_size,), generator=gen).float() * 2 - 1
        else:
            bucket = torch.empty(vocab_size, dtype=torch.long)
            sign = torch.empty(vocab_size)
        self.register_buffer("bucket", bucket)
        self.register_buffer("sign", sign)

    @property
    def rank(self):
        return self.rank_exact + self.rank_tail

    @torch.no_grad()
    def compress(self, grad_logits, targets):
        """Compress ``grad_logits`` (..., V) into (..., rank) codes.

        Args:
            grad_logits: gradient w.r.t. the logits, last dim is the vocabulary.
            targets: integer targets whose shape matches ``grad_logits.shape[:-1]``.

        Returns:
            Tensor of shape ``(..., rank)`` in ``grad_logits.dtype``.
        """
        g = grad_logits.float()
        V = g.size(-1)
        orig_shape = g.shape[:-1]
        g_flat = g.reshape(-1, V)
        t_flat = targets.reshape(-1).long()
        N = g_flat.size(0)

        valid = (t_flat >= 0).unsqueeze(1)  # (N, 1): False where target is ignored
        safe_t = t_flat.clamp(min=0).unsqueeze(1)  # (N, 1) index usable by gather/scatter
        k_eff = min(self.topk, max(V - 1, 0))

        parts = []
        topi = None
        # Exact head: gold + top-k
        if self.rank_exact > 0:
            codes = self.exact_codes.float()
            gold_grad = g_flat.gather(1, safe_t) * valid.to(g_flat.dtype)  # (N, 1)
            c_exact = gold_grad * codes[safe_t.squeeze(1)]  # (N, r_exact)
            if k_eff > 0:
                topv, topi = g_flat.topk(k_eff, dim=1)  # (N, k)
                # The gold token is already encoded above; never count it twice.
                topv = topv.masked_fill((topi == safe_t) & valid, 0.0)
                c_exact = c_exact + (topv.unsqueeze(-1) * codes[topi]).sum(dim=1)
            parts.append(c_exact)

        # Tail CountSketch over everything the head does not carry
        if self.rank_tail > 0:
            signed = g_flat * self.sign.float().unsqueeze(0)  # fresh (N, V), safe to edit
            if self.rank_exact > 0:
                # Zero excluded entries *before* sketching: exact (no float
                # cancellation) and immune to several top-k tokens sharing a bucket.
                if topi is not None:
                    signed.scatter_(1, topi, 0.0)
                gold_col = signed.gather(1, safe_t)
                signed.scatter_(1, safe_t, gold_col.masked_fill(valid, 0.0))
            c_tail = g_flat.new_zeros(N, self.rank_tail)
            c_tail.scatter_add_(1, self.bucket.unsqueeze(0).expand(N, V), signed)
            parts.append(c_tail)

        c = torch.cat(parts, dim=1)
        return c.reshape(*orig_shape, self.rank).to(dtype=grad_logits.dtype)


class _FactorizedExitFn(torch.autograd.Function):
    """Exact logits forward; exact grad_W and compressed/expanded grad_h backward.

    backward returns:
        grad_h      = (alpha * compressor.compress(grad_logits, targets)) @ U^T
        grad_weight = grad_logits^T @ h   (flattened over leading dims)
    """

    @staticmethod
    def forward(ctx, h, weight, targets, U, alpha, compressor):
        ctx.compressor = compressor
        ctx.save_for_backward(h.detach(), weight.detach(), targets, U.detach(), alpha.detach())
        # The logits are not needed in backward, so no copy of them is kept.
        return h @ weight.t()

    @staticmethod
    def backward(ctx, grad_logits):
        h, weight, targets, U, alpha = ctx.saved_tensors
        grad_h = grad_weight = None

        # Exact W gradient (no transport)
        if ctx.needs_input_grad[1]:
            g_flat = grad_logits.reshape(-1, grad_logits.size(-1)).float()
            h_flat = h.reshape(-1, h.size(-1)).float()
            grad_weight = (g_flat.t() @ h_flat).to(weight.dtype)

        # BP-free hidden signal via compressor
        if ctx.needs_input_grad[0]:
            c = ctx.compressor.compress(grad_logits, targets).float()
            grad_h = ((alpha * c) @ U.float().t()).to(h.dtype)

        return grad_h, grad_weight, None, None, None, None


class _ExactParallelExitFn(torch.autograd.Function):
    """Exit backward using only the exact recoverable parallel component.

        g_h_parallel = ((p-y)^T z / (||h||^2 + eps)) * h

    This is the ONLY component of W_U^T(p-y) that is identifiable from forward
    quantities alone. The h-perp component is informationally invisible
    without W_U.

    NOTE: grad_h is built from the recomputed softmax error p - y, not from
    grad_logits, so it does not include upstream loss scaling (e.g. a mean
    reduction); grad_weight does use grad_logits. Positions with negative
    targets get zero error (and therefore zero grad_h).
    """

    @staticmethod
    def forward(ctx, h, weight, targets, residual_fn):
        logits = h @ weight.t()
        ctx.save_for_backward(h.detach(), weight.detach(), targets)
        ctx.logits_detached = logits.detach()
        ctx.residual_fn = residual_fn  # optional h-perp residual
        return logits

    @staticmethod
    def backward(ctx, grad_logits):
        h, weight, targets = ctx.saved_tensors
        logits = ctx.logits_detached
        residual_fn = ctx.residual_fn
        grad_h = grad_weight = None

        # Exact W gradient (no transport)
        if ctx.needs_input_grad[1]:
            g_flat = grad_logits.reshape(-1, grad_logits.size(-1)).float()
            h_flat = h.reshape(-1, h.size(-1)).float()
            grad_weight = (g_flat.t() @ h_flat).to(weight.dtype)

        if ctx.needs_input_grad[0]:
            # e = p - y_onehot, built in place on the fresh softmax output to
            # avoid materializing a separate (..., V) one-hot tensor.
            e = F.softmax(logits, dim=-1)
            target_idx = targets.clamp(min=0).unsqueeze(-1)
            e.scatter_add_(-1, target_idx, torch.full_like(target_idx, -1.0, dtype=e.dtype))
            # Ignored positions (negative targets) carry no error signal.
            e.masked_fill_((targets < 0).unsqueeze(-1), 0.0)

            # p^T z - z_y = (p-y)^T z
            e_dot_z = (e * logits).sum(dim=-1, keepdim=True)  # (..., 1)
            h32 = h.float()
            h_norm_sq = (h32 * h32).sum(dim=-1, keepdim=True) + 1e-8  # (..., 1)
            grad_h = (e_dot_z / h_norm_sq) * h32  # (..., d)

            # Optional orthogonal residual
            if residual_fn is not None:
                residual = residual_fn(h32, e, logits, targets)
                if residual is not None:
                    grad_h = grad_h + residual
            grad_h = grad_h.to(h.dtype)

        return grad_h, grad_weight, None, None


class FactorizedExitHead(nn.Module):
    """Drop-in BP-free local CE exit head.

    Args:
        d_model: hidden size d.
        vocab_size: vocabulary size V.
        mode: ``"dense"`` or ``"hybrid"`` compressor.
        rank: total compressed rank r (hybrid: ``rank_exact + tail``).
        rank_exact: hybrid exact-head rank; must not exceed ``rank``.
        topk: hybrid number of top-gradient tokens given exact codes.
        alpha_init: initial value of the scalar gain buffer.
        seed: seed for the compressor and the expander U.

    NOTE(review): compressor and U are drawn from generators seeded with the
    same ``seed``, so their random streams overlap; confirm this is intended.

    Raises:
        ValueError: on an unknown ``mode`` or invalid rank configuration.
    """

    def __init__(self, d_model, vocab_size, *, mode="hybrid", rank=128, rank_exact=32,
                 topk=8, alpha_init=1.0, seed=None):
        super().__init__()
        if mode == "dense":
            self.compressor = DenseRandomCompressor(vocab_size, rank, seed=seed)
        elif mode == "hybrid":
            self.compressor = HybridTopKTailSketchCompressor(
                vocab_size, rank_exact=rank_exact, rank_tail=rank - rank_exact, topk=topk, seed=seed
            )
        else:
            raise ValueError(f"Unknown mode: {mode}")
        U = orthonormal_columns(d_model, self.compressor.rank, seed=seed)
        self.register_buffer("U", U)
        # float() so an integer alpha_init does not create an int64 buffer.
        self.register_buffer("alpha", torch.tensor(float(alpha_init)))
        self.vocab_size = vocab_size

    def forward(self, h, shared_weight, targets):
        """h: (B,T,d), shared_weight: (V,d), targets: (B,T) → logits: (B,T,V)"""
        return _FactorizedExitFn.apply(h, shared_weight, targets, self.U, self.alpha, self.compressor)


class ExactParallelExitHead(nn.Module):
    """BP-free exit using exact parallel component + optional h-perp residual.

    Modes:
        parallel_only:    g̃_h = (e^T z / ||h||²) h                  (exact parallel only)
        parallel_gold:    + λ R(h) (e_y q_y)                         (+ gold token code in h⊥)
        parallel_topmass: + λ R(h) (e_y q_y + Σ_{j∈S, j≠y} e_j q_j)  (+ top-mass codes in h⊥)

    S is the smallest prefix of the (at most ``mass_topk``) most probable
    tokens whose cumulative probability stays within ``mass_threshold``
    (always at least the top-1 token); the gold token y is excluded from the
    sum because its code is already added once.

    Raises:
        ValueError: on an unknown ``mode``.
    """

    _MODES = ("parallel_only", "parallel_gold", "parallel_topmass")

    def __init__(self, d_model, vocab_size, *, mode="parallel_only", residual_rank=32,
                 residual_lambda=0.1, mass_threshold=0.95, seed=None, mass_topk=200):
        super().__init__()
        if mode not in self._MODES:
            raise ValueError(f"Unknown mode: {mode}")
        self.vocab_size = vocab_size
        self.mode = mode
        self.residual_lambda = residual_lambda
        self.mass_threshold = mass_threshold
        self.mass_topk = mass_topk  # candidate pool for top-mass (avoids a full sort)
        if mode in ("parallel_gold", "parallel_topmass"):
            # Fixed random unit-norm token codes for the h-perp residual
            gen = torch.Generator()
            if seed is not None:
                gen.manual_seed(seed)
            codes = torch.randn(vocab_size, residual_rank, generator=gen)
            codes = codes / codes.norm(dim=-1, keepdim=True).clamp_min(1e-6)
            self.register_buffer("token_codes", codes)
            # Fixed base Q for constructing R(h) ∈ h⊥
            Q = orthonormal_columns(d_model, residual_rank, seed=seed)
            self.register_buffer("Q_base", Q)
        else:
            self.token_codes = None
            self.Q_base = None

    def _residual_fn(self, h, e, logits, targets):
        """Compute the h-perp residual λ R(h) C_head(e); None for parallel_only.

        R(h) maps code space into h⊥ by projecting each column of ``Q_base``
        onto h⊥ and renormalizing it to unit length.
        """
        if self.mode == "parallel_only" or self.token_codes is None:
            return None
        batch_shape = h.shape[:-1]

        # R(h) without materializing an (N, d, r) tensor
        h_hat = h / (h.norm(dim=-1, keepdim=True) + 1e-8)  # (..., d)
        Q = self.Q_base.float()  # (d, r)
        hQ = h_hat @ Q  # (..., r) = h_hat^T Q per token
        # ||Q_j - h_hat hQ_j||^2 = 1 - hQ_j^2 (unit-norm Q columns and h_hat)
        col_norm_inv = (1.0 - hQ ** 2).clamp_min(1e-8).rsqrt()  # (..., r)

        # C_head(e): gold + (optionally) top-mass codes
        t_flat = targets.reshape(-1).clamp(min=0)
        e_flat = e.reshape(-1, self.vocab_size)
        N = e_flat.size(0)
        codes = self.token_codes.float()
        rank = codes.size(1)
        gold_grad = e_flat.gather(1, t_flat.unsqueeze(1))  # (N, 1)
        c = gold_grad * codes[t_flat]  # (N, r)

        if self.mode == "parallel_topmass":
            k_pre = min(self.mass_topk, self.vocab_size - 1)
            if k_pre > 0 and N > 0:
                # Adaptive top-mass via topk + cumulative mass (avoids a full sort)
                p_flat = F.softmax(logits.reshape(-1, self.vocab_size).float(), dim=-1)
                top_p, top_idx = p_flat.topk(k_pre, dim=1)  # (N, k_pre)
                keep_mask = top_p.cumsum(dim=-1) <= self.mass_threshold
                keep_mask[:, 0] = True
                # Gold's code is already in c; do not add it a second time.
                keep_mask &= top_idx != t_flat.unsqueeze(1)
                top_e_masked = e_flat.gather(1, top_idx) * keep_mask.float()  # (N, k_pre)
                # Chunked so codes[top_idx] never materializes (N, k_pre, r) at once
                chunk_size = 1024
                for start in range(0, N, chunk_size):
                    stop = min(start + chunk_size, N)
                    chunk_codes = codes[top_idx[start:stop]]  # (chunk, k_pre, r)
                    c[start:stop] += (top_e_masked[start:stop].unsqueeze(-1) * chunk_codes).sum(dim=1)

        c = c.reshape(*batch_shape, rank)  # (..., r)

        # R(h) @ c = Σ_j (c_j / ||Q_bar_j||) (Q_j - h_hat hQ_j)
        #          = Q @ c_adj - h_hat (c_adj · hQ)        -- O(N (d + r)) memory
        c_adj = c * col_norm_inv  # (..., r)
        residual = c_adj @ Q.t() - h_hat * (c_adj * hQ).sum(dim=-1, keepdim=True)
        return self.residual_lambda * residual

    def forward(self, h, shared_weight, targets):
        """h: (B,T,d), shared_weight: (V,d), targets: (B,T) → logits: (B,T,V)"""
        residual_fn = self._residual_fn if self.mode != "parallel_only" else None
        return _ExactParallelExitFn.apply(h, shared_weight, targets, residual_fn)