summaryrefslogtreecommitdiff
path: root/ep_run/factorized_exit.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/factorized_exit.py')
-rw-r--r--ep_run/factorized_exit.py330
1 files changed, 330 insertions, 0 deletions
diff --git a/ep_run/factorized_exit.py b/ep_run/factorized_exit.py
new file mode 100644
index 0000000..fbf66a8
--- /dev/null
+++ b/ep_run/factorized_exit.py
@@ -0,0 +1,330 @@
+"""Factorized BP-free exit feedback for local CE training.
+
+Replaces W_U^T(p-y) with α · C(p-y) @ U^T where:
+ C: fixed compressor (dense random or hybrid gold+topk+tail-sketch)
+ U: fixed orthonormal expander (d, r)
+ α: scalar gain
+
+Forward logits = h @ W_U^T (exact, unchanged)
+grad_W = exact local CE gradient (no weight transport)
+grad_h = factorized BP-free signal (no W_U^T)
+
+Two compressor modes:
+ dense: g @ C where C is (V, r) fixed random
+ hybrid: [gold + top-k exact codes, CountSketch(tail)]
+"""
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
def orthonormal_columns(d_out, rank, *, device=None, dtype=torch.float32, seed=None):
    """Return a ``(d_out, rank)`` matrix whose columns are orthonormal.

    A Gaussian matrix is drawn on the CPU (seeded when ``seed`` is given) and
    the Q factor of its reduced QR decomposition is returned.

    Args:
        d_out: Number of rows of the result.
        rank: Number of orthonormal columns; must satisfy ``1 <= rank <= d_out``.
        device: Device the result is moved to (``None`` keeps it on the CPU).
        dtype: Dtype used for the random draw and the factorization.
        seed: Optional seed for the CPU generator.

    Raises:
        ValueError: If ``rank`` is outside ``[1, d_out]``.
    """
    if rank > d_out or rank <= 0:
        raise ValueError(f"rank must satisfy 1 <= rank <= d_out, got rank={rank}, d_out={d_out}")

    rng = torch.Generator(device="cpu")
    if seed is not None:
        rng.manual_seed(seed)

    gaussian = torch.randn(d_out, rank, dtype=dtype, generator=rng)
    basis, _upper = torch.linalg.qr(gaussian, mode="reduced")
    return basis.contiguous().to(device=device)
+
+
class DenseRandomCompressor(nn.Module):
    """Compress logit gradients with a fixed dense Gaussian codebook.

    The codebook has shape ``(vocab_size, rank)``, is scaled by
    ``1 / sqrt(rank)`` and is stored as a (non-trainable) buffer.
    """

    def __init__(self, vocab_size, rank, *, seed=None):
        super().__init__()
        self.vocab_size = vocab_size
        self._rank = rank

        rng = torch.Generator()
        if seed is not None:
            rng.manual_seed(seed)
        gaussian = torch.randn(vocab_size, rank, generator=rng)
        self.register_buffer("codebook", gaussian / math.sqrt(rank))

    @property
    def rank(self):
        """Width of the compressed representation."""
        return self._rank

    @torch.no_grad()
    def compress(self, grad_logits, targets):
        """Project ``grad_logits`` (..., V) onto the codebook, giving (..., rank).

        ``targets`` is accepted for interface parity with the other
        compressors and is not used. The result is always float32.
        """
        projection = self.codebook.float()
        return grad_logits.float() @ projection
+
+
class HybridTopKTailSketchCompressor(nn.Module):
    """Hybrid compressor: exact random codes for the head, CountSketch for the tail.

    For every row of the logit gradient the "head" is the gold (target) token
    plus the ``topk`` largest entries; each head token contributes
    ``grad * exact_codes[token]`` to the first ``rank_exact`` output
    coordinates. Every remaining ("tail") token is hashed into one of
    ``rank_tail`` buckets with a random sign (CountSketch) and fills the last
    ``rank_tail`` coordinates.

    Args:
        vocab_size: Size V of the last dimension of the gradients.
        rank_exact: Width of the exact-code part (0 disables it).
        rank_tail: Number of CountSketch buckets (0 disables the sketch).
        topk: Number of largest-gradient tokens treated as head tokens.
        seed: Optional seed for the fixed random codes / hashes.

    Raises:
        ValueError: If a rank is negative or both ranks are zero.
    """

    def __init__(self, vocab_size, *, rank_exact=32, rank_tail=96, topk=8, seed=None):
        super().__init__()
        if rank_exact < 0 or rank_tail < 0 or rank_exact + rank_tail == 0:
            raise ValueError(
                "rank_exact and rank_tail must be non-negative and not both zero, "
                f"got rank_exact={rank_exact}, rank_tail={rank_tail}"
            )
        self.vocab_size = vocab_size
        self.rank_exact = rank_exact
        self.rank_tail = rank_tail
        self.topk = topk

        gen = torch.Generator()
        if seed is not None:
            gen.manual_seed(seed)

        # Draw order (codes, bucket, sign) is kept so a given seed reproduces
        # the same buffers as before.
        if rank_exact > 0:
            codes = torch.randn(vocab_size, rank_exact, generator=gen)
            codes = codes / codes.norm(dim=-1, keepdim=True).clamp_min(1e-6)
        else:
            codes = torch.empty(vocab_size, 0)
        self.register_buffer("exact_codes", codes)

        if rank_tail > 0:
            bucket = torch.randint(0, rank_tail, (vocab_size,), generator=gen)
            sign = torch.randint(0, 2, (vocab_size,), generator=gen).float() * 2 - 1
        else:
            # Never read when the sketch is disabled; zeros keep the
            # state_dict free of uninitialised memory.
            bucket = torch.zeros(vocab_size, dtype=torch.long)
            sign = torch.zeros(vocab_size)
        self.register_buffer("bucket", bucket)
        self.register_buffer("sign", sign)

    @property
    def rank(self):
        """Total width of the compressed representation."""
        return self.rank_exact + self.rank_tail

    @torch.no_grad()
    def compress(self, grad_logits, targets):
        """Compress ``grad_logits`` (..., V) into (..., rank).

        Args:
            grad_logits: Gradient w.r.t. the logits, last dimension V.
            targets: Integer target ids with the leading shape of
                ``grad_logits``; negative ids are clamped to 0.

        Returns:
            Tensor of shape ``(*grad_logits.shape[:-1], rank)`` in the dtype
            of ``grad_logits``.
        """
        g_flat = grad_logits.float().reshape(-1, grad_logits.size(-1))
        n_rows, vocab = g_flat.shape

        gold_idx = targets.reshape(-1).long().clamp(min=0).unsqueeze(1)  # (N, 1)
        gold_grad = g_flat.gather(1, gold_idx)  # (N, 1)

        # The top-k selection is shared by head and tail, so it is computed
        # up front (the tail needs it even when rank_exact == 0).
        k_eff = min(self.topk, max(vocab - 1, 0))
        if k_eff > 0:
            top_val, top_idx = g_flat.topk(k_eff, dim=1)  # (N, k)
            # If the gold token is also among the top-k, count it only once.
            top_val = top_val.masked_fill(top_idx == gold_idx, 0.0)
        else:
            top_val = top_idx = None

        parts = []

        # Exact head: gold + top-k tokens, each with its own fixed code.
        if self.rank_exact > 0:
            codes = self.exact_codes.float()
            c_exact = gold_grad * codes[gold_idx.squeeze(1)]  # (N, r_exact)
            if top_idx is not None:
                c_exact = c_exact + (top_val.unsqueeze(-1) * codes[top_idx]).sum(dim=1)
            parts.append(c_exact)

        # Tail CountSketch over every token that is not a head token.
        if self.rank_tail > 0:
            signed = g_flat * self.sign.unsqueeze(0)  # fresh tensor, safe to edit
            # Zero the head entries before hashing. Subtracting them afterwards
            # with indexed `-=` drops updates when two head tokens of a row
            # share a bucket (duplicate indices do not accumulate).
            signed.scatter_(1, gold_idx, 0.0)
            if top_idx is not None:
                signed.scatter_(1, top_idx, 0.0)
            c_tail = g_flat.new_zeros(n_rows, self.rank_tail)
            c_tail.scatter_add_(1, self.bucket.unsqueeze(0).expand(n_rows, vocab), signed)
            parts.append(c_tail)

        c = torch.cat(parts, dim=1)
        return c.reshape(*grad_logits.shape[:-1], self.rank).to(dtype=grad_logits.dtype)
+
+
+class _FactorizedExitFn(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, h, weight, targets, U, alpha, compressor):
+ logits = h @ weight.t()
+ ctx.compressor = compressor
+ ctx.save_for_backward(h.detach(), weight.detach(), targets, U.detach(), alpha.detach())
+ ctx.logits_detached = logits.detach()
+ return logits
+
+ @staticmethod
+ def backward(ctx, grad_logits):
+ h, weight, targets, U, alpha = ctx.saved_tensors
+ compressor = ctx.compressor
+ logits = ctx.logits_detached
+
+ # Exact W gradient (no transport)
+ g_flat = grad_logits.reshape(-1, grad_logits.size(-1)).float()
+ h_flat = h.reshape(-1, h.size(-1)).float()
+ grad_weight = g_flat.t() @ h_flat
+
+ # BP-free hidden signal via compressor
+ c = compressor.compress(grad_logits, targets).float()
+ grad_h = (alpha * c) @ U.float().t()
+
+ return grad_h.to(h.dtype), grad_weight.to(weight.dtype), None, None, None, None
+
+
+class _ExactParallelExitFn(torch.autograd.Function):
+ """Exit backward using only the exact recoverable parallel component.
+
+ g_h_parallel = ((p-y)^T z / (||h||^2 + eps)) * h
+
+ This is the ONLY component of W_U^T(p-y) that is identifiable from
+ forward quantities alone. The h-perp component is informationally
+ invisible without W_U.
+ """
+ @staticmethod
+ def forward(ctx, h, weight, targets, residual_fn):
+ logits = h @ weight.t()
+ ctx.save_for_backward(h.detach(), weight.detach(), targets)
+ ctx.logits_detached = logits.detach()
+ ctx.residual_fn = residual_fn # optional h-perp residual
+ return logits
+
+ @staticmethod
+ def backward(ctx, grad_logits):
+ h, weight, targets = ctx.saved_tensors
+ logits = ctx.logits_detached
+ residual_fn = ctx.residual_fn
+
+ # Exact W gradient (no transport)
+ g_flat = grad_logits.reshape(-1, grad_logits.size(-1)).float()
+ h_flat = h.reshape(-1, h.size(-1)).float()
+ grad_weight = g_flat.t() @ h_flat
+
+ # Exact parallel component: (p^T z - z_y) / (||h||^2 + eps) * h
+ # Memory-efficient: avoid materializing y_onehot (B,T,V) tensor.
+ # e = p - y_onehot computed in-place by subtracting 1 at target indices.
+ p = F.softmax(logits, dim=-1) # (..., V)
+ V = p.size(-1)
+ e = p # in-place: e will be (p - y_onehot)
+ target_idx = targets.clamp(min=0).unsqueeze(-1)
+ e.scatter_add_(-1, target_idx, torch.full_like(target_idx, -1.0, dtype=e.dtype))
+
+ # p^T z - z_y = (p-y)^T z (since p^T z - z_y = sum_j p_j z_j - z_y)
+ e_dot_z = (e * logits).sum(dim=-1, keepdim=True) # (..., 1)
+ h_norm_sq = (h.float() * h.float()).sum(dim=-1, keepdim=True) + 1e-8 # (..., 1)
+
+ grad_h = (e_dot_z / h_norm_sq) * h.float() # (..., d)
+
+ # Optional orthogonal residual
+ if residual_fn is not None:
+ residual = residual_fn(h.float(), e, logits, targets)
+ grad_h = grad_h + residual
+
+ return grad_h.to(h.dtype), grad_weight.to(weight.dtype), None, None
+
+
class FactorizedExitHead(nn.Module):
    """Drop-in BP-free local CE exit head.

    Owns a fixed compressor, a fixed orthonormal expander ``U`` of shape
    ``(d_model, compressor.rank)`` and a scalar gain ``alpha`` (both buffers),
    and routes the forward pass through ``_FactorizedExitFn``.

    Args:
        d_model: Hidden size d.
        vocab_size: Vocabulary size V.
        mode: ``"dense"`` (random codebook of width ``rank``) or ``"hybrid"``
            (``rank_exact`` exact-code columns + ``rank - rank_exact``
            CountSketch buckets).
        rank: Total width of the compressed signal.
        rank_exact: Width of the exact-code part (hybrid mode only).
        topk: Number of top-gradient tokens coded exactly (hybrid mode only).
        alpha_init: Initial value of the gain ``alpha``.
        seed: Optional seed shared by the compressor and ``U``.

    Raises:
        ValueError: For an unknown ``mode`` or, in hybrid mode, when ``rank``
            is smaller than ``rank_exact``.
    """

    def __init__(self, d_model, vocab_size, *, mode="hybrid", rank=128,
                 rank_exact=32, topk=8, alpha_init=1.0, seed=None):
        super().__init__()
        if mode == "dense":
            self.compressor = DenseRandomCompressor(vocab_size, rank, seed=seed)
        elif mode == "hybrid":
            # A negative tail width would only surface as a shape error deep
            # inside backward; reject it where the cause is obvious.
            if rank < rank_exact:
                raise ValueError(
                    f"hybrid mode requires rank >= rank_exact, got rank={rank}, rank_exact={rank_exact}"
                )
            self.compressor = HybridTopKTailSketchCompressor(
                vocab_size, rank_exact=rank_exact, rank_tail=rank - rank_exact, topk=topk, seed=seed
            )
        else:
            raise ValueError(f"Unknown mode: {mode}")

        U = orthonormal_columns(d_model, self.compressor.rank, seed=seed)
        self.register_buffer("U", U)
        # float(...) keeps alpha a floating-point buffer even for alpha_init=1.
        self.register_buffer("alpha", torch.tensor(float(alpha_init)))
        self.vocab_size = vocab_size

    def forward(self, h, shared_weight, targets):
        """h: (B,T,d), shared_weight: (V,d), targets: (B,T) -> logits: (B,T,V)"""
        return _FactorizedExitFn.apply(h, shared_weight, targets, self.U, self.alpha, self.compressor)
+
+
class ExactParallelExitHead(nn.Module):
    """BP-free exit using exact parallel component + optional h-perp residual.

    Modes:
        parallel_only:    g_h = (e^T z / ||h||^2) h          (exact parallel only)
        parallel_gold:    + lambda * R(h) (e_y q_y)          (+ gold token code in h-perp)
        parallel_topmass: + lambda * R(h) sum_{j in S u {y}} e_j q_j
                          (+ codes of the top-probability-mass set S in h-perp;
                          the gold token y is counted once even if it is in S)

    Here ``e = p - y`` is the softmax error, ``q_j`` are fixed unit-norm random
    token codes and ``R(h)`` maps a code vector onto the fixed orthonormal
    basis ``Q_base`` after projecting each basis column onto the complement
    of ``h`` and renormalizing it.

    Args:
        d_model: Hidden size d.
        vocab_size: Vocabulary size V.
        mode: One of the three modes above.
        residual_rank: Width r of the token codes / number of columns of Q_base.
        residual_lambda: Scale applied to the residual.
        mass_threshold: Cumulative probability bound for the top-mass set.
        seed: Optional seed for the token codes and Q_base.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
    """

    _MODES = ("parallel_only", "parallel_gold", "parallel_topmass")

    def __init__(self, d_model, vocab_size, *, mode="parallel_only",
                 residual_rank=32, residual_lambda=0.1, mass_threshold=0.95, seed=None):
        super().__init__()
        if mode not in self._MODES:
            # An unknown mode used to be accepted and only failed in backward.
            raise ValueError(f"Unknown mode: {mode}")
        self.vocab_size = vocab_size
        self.mode = mode
        self.residual_lambda = residual_lambda
        self.mass_threshold = mass_threshold

        if mode in ("parallel_gold", "parallel_topmass"):
            # Fixed random token codes for h-perp residual
            gen = torch.Generator()
            if seed is not None:
                gen.manual_seed(seed)
            codes = torch.randn(vocab_size, residual_rank, generator=gen)
            codes = codes / codes.norm(dim=-1, keepdim=True).clamp_min(1e-6)
            self.register_buffer("token_codes", codes)

            # Fixed base Q for constructing R(h) in h-perp
            Q = orthonormal_columns(d_model, residual_rank, seed=seed)
            self.register_buffer("Q_base", Q)
        else:
            self.token_codes = None
            self.Q_base = None

    def _residual_fn(self, h, e, logits, targets):
        """Compute the h-perp residual ``lambda * R(h) C_head(e)``.

        Args:
            h: Hidden states (..., d).
            e: Softmax error ``p - y`` (..., V).
            logits: Logits (..., V); only read in ``parallel_topmass`` mode.
            targets: Target ids (...); negative ids mark ignored positions,
                which receive a zero residual.

        Returns:
            Residual of shape (..., d), or ``None`` when no residual applies.
        """
        if self.mode == "parallel_only" or self.token_codes is None:
            return None

        lead_shape = h.shape[:-1]

        # R(h): project Q_base into h-perp without materializing (N, d, r).
        h_hat = h / (h.norm(dim=-1, keepdim=True) + 1e-8)  # (..., d)
        Q = self.Q_base.float()  # (d, r)
        hQ = h_hat @ Q  # (..., r) = h_hat^T Q per token
        # Norm of a projected column is sqrt(1 - hQ_j^2) because the columns of
        # Q and h_hat are unit-norm; the clamp guards columns parallel to h.
        col_norm_inv = (1.0 - hQ ** 2).clamp_min(1e-8).rsqrt()  # (..., r)

        # C_head(e): gold + (optionally) top-mass codes
        t_flat = targets.reshape(-1).long()
        valid = (t_flat >= 0).unsqueeze(1)  # (N, 1)
        gold_idx = t_flat.clamp(min=0).unsqueeze(1)  # (N, 1)
        e_flat = e.reshape(-1, self.vocab_size)
        n_rows = e_flat.size(0)

        codes = self.token_codes.float()
        c = e_flat.gather(1, gold_idx) * codes[gold_idx.squeeze(1)]  # (N, r)

        k_pre = min(200, self.vocab_size - 1)
        if self.mode == "parallel_topmass" and k_pre > 0 and n_rows > 0:
            # Adaptive top-mass via topk(200) + cumulative mass (avoids a full sort).
            p_flat = F.softmax(logits.reshape(-1, self.vocab_size).float(), dim=-1)
            top_p, top_idx = p_flat.topk(k_pre, dim=1)  # (N, k_pre)
            keep = top_p.cumsum(dim=-1) <= self.mass_threshold
            keep[:, 0] = True  # always keep the most probable token
            # The gold token is already in `c`; do not add it a second time.
            keep &= top_idx != gold_idx
            top_e = e_flat.gather(1, top_idx) * keep.to(e_flat.dtype)  # (N, k_pre)
            # Chunked so codes[top_idx] never becomes a full (N, k_pre, r) tensor.
            chunk = 1024
            for start in range(0, n_rows, chunk):
                stop = start + chunk
                chunk_codes = codes[top_idx[start:stop]]  # (chunk, k_pre, r)
                c[start:stop] += (top_e[start:stop].unsqueeze(-1) * chunk_codes).sum(dim=1)

        # Ignored positions (negative target) carry no training signal.
        c = c * valid.to(c.dtype)
        c = c.reshape(*lead_shape, -1)  # (..., r)

        # R(h) @ c = sum_j (c_j / ||Qbar_j||) * (Q_j - h_hat * hQ_j)
        #          = Q @ c_adj - h_hat * (c_adj . hQ)
        c_adj = c * col_norm_inv  # (..., r)
        residual = c_adj @ Q.t() - h_hat * (c_adj * hQ).sum(dim=-1, keepdim=True)
        return self.residual_lambda * residual

    def forward(self, h, shared_weight, targets):
        """h: (..., d), shared_weight: (V, d), targets: (...) -> logits: (..., V)"""
        residual_fn = self._residual_fn if self.mode != "parallel_only" else None
        return _ExactParallelExitFn.apply(h, shared_weight, targets, residual_fn)