diff options
Diffstat (limited to 'ep_run/model_local.py')
| -rw-r--r-- | ep_run/model_local.py | 470 |
1 files changed, 470 insertions, 0 deletions
# Source: diff --git a/ep_run/model_local.py b/ep_run/model_local.py (new file, index 0000000..a84c692)
"""Sigmoid GPT with split Q/K/V projections and LocalLinear for method dispatch.

Derived from model.py but uses LocalLinear for every linear layer and splits
the fused qkv into separate q_proj, k_proj, v_proj so that each projection has
its own feedback matrix for FA / DFA / sign_sym.
"""
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

from local_layers import LocalLinear


class SigmoidSTE(torch.autograd.Function):
    """Sigmoid forward, straight-through backward (skip A(1-A) derivative)."""

    @staticmethod
    def forward(ctx, x):
        return torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_out):
        # Identity: the incoming gradient is returned unchanged.
        return grad_out


class GELUSTE(torch.autograd.Function):
    """GELU forward, straight-through backward (skip gelu' derivative)."""

    @staticmethod
    def forward(ctx, x):
        return F.gelu(x)

    @staticmethod
    def backward(ctx, grad_out):
        # Identity: the incoming gradient is returned unchanged.
        return grad_out


class HardTopK(torch.autograd.Function):
    """k-WTA: zero out all but top-k (by abs value) along last dim, in BOTH forward and backward.

    Forward: keep top-k entries, zero rest.
    Backward: gradient mask = forward mask (only winners get gradient).

    This enforces strict sparsity — non-selected channels never update.
    """

    @staticmethod
    def forward(ctx, x, k):
        # Only the winner indices are needed; the top-k values are discarded.
        _, topk_idx = x.abs().topk(k, dim=-1)
        mask = torch.zeros_like(x).scatter_(-1, topk_idx, 1.0)
        ctx.save_for_backward(mask)
        return x * mask

    @staticmethod
    def backward(ctx, grad_out):
        (mask,) = ctx.saved_tensors
        # No gradient for the integer argument k.
        return grad_out * mask, None


class FrozenSubspace(nn.Module):
    """Project h to fixed r-dim orthonormal subspace via Q Q^T h.

    Q ∈ R^{d × r} is a random orthonormal basis, frozen at init.
    Output lives in span(Q) ⊂ R^d. (d-r) directions are killed.

    With same seed across blocks, all layers share the same subspace —
    so the residual stream is constrained to span(Q) throughout the network.

    Differentiable (no STE): grad_h = Q Q^T grad_out (same projection).
    For BPfree: residual codebook subspace ≈ span(Q) is exactly what BPfree
    delivers gradient on, so feedback geometry is matched by construction.
    """

    def __init__(self, d_model, rank, seed=42):
        super().__init__()
        self.rank = rank
        gen = torch.Generator()
        gen.manual_seed(seed)
        Q, _ = torch.linalg.qr(torch.randn(d_model, rank, generator=gen))
        # Buffer, not Parameter: Q is never handed to the optimizer.
        self.register_buffer("Q", Q)  # (d, r)

    def forward(self, h):
        # h @ Q → (..., r) coefficients in basis
        # @ Q.t() → back to (..., d), now in span(Q)
        return h @ self.Q @ self.Q.t()


class VQResidualDir(nn.Module):
    """Directional quantization to fixed codebook with STE backward.

    Forward: replace h's direction with nearest of K fixed unit-norm codebook entries
    (per token). Magnitude is preserved.
    Backward: identity through h (STE — no gradient on the codebook lookup itself).

    Codebook is initialized with random unit-norm directions and FROZEN (registered buffer).
    The "feature directions" are predefined — the network only learns *which code to land on*
    per token per layer. Discrete bottleneck: log2(K) bits per token per layer.

    For BPfree: the gradient signal needed to switch between codes is in
    {radial, low-rank residual} subspace, matching BPfree exit's bandwidth.
    """

    def __init__(self, d_model, n_codes, seed=None):
        super().__init__()
        self.n_codes = n_codes
        gen = torch.Generator()
        # With seed=None the freshly constructed generator is used without re-seeding.
        if seed is not None:
            gen.manual_seed(seed)
        codes = torch.randn(n_codes, d_model, generator=gen)
        codes = codes / codes.norm(dim=-1, keepdim=True).clamp_min(1e-8)
        self.register_buffer("codebook", codes)

    def forward(self, h):
        h_norm = h.norm(dim=-1, keepdim=True).clamp_min(1e-8)
        h_hat = h / h_norm
        sims = h_hat @ self.codebook.t()  # (..., K)
        idx = sims.argmax(dim=-1)
        z_q_dir = self.codebook[idx]  # (..., d) unit-norm direction
        z_q = z_q_dir * h_norm  # restore magnitude
        # STE: forward = z_q, backward = identity through h
        return h + (z_q - h).detach()


class LayerNormSTE(nn.Module):
    """LayerNorm forward, straight-through backward (gradient passes through as identity).

    The wrapped nn.LayerNorm is evaluated under no_grad and its output is
    detached, so its affine parameters receive no gradient through this module.
    """

    def __init__(self, normalized_shape):
        super().__init__()
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, x):
        with torch.no_grad():
            out = self.ln(x)
        # Value equals `out`; the only differentiable path is the bare `x` term.
        return x + (out - x).detach()


class _ProjectedSurrogateLNFn(torch.autograd.Function):
    """Core autograd function for projected surrogate LN backward.

    mode='projected': full P_z(v) = v - mean(v) - z*mean(v*z), scaled by 1/σ
    mode='center_scale': only v - mean(v), scaled by 1/σ (no radial removal)
    """

    @staticmethod
    def forward(ctx, x, eps, mode):
        # Half-precision inputs are normalized in float32, then cast back.
        x_f = x.float() if x.dtype in (torch.float16, torch.bfloat16) else x
        mu = x_f.mean(dim=-1, keepdim=True)
        xc = x_f - mu
        var = (xc * xc).mean(dim=-1, keepdim=True)
        rsigma = torch.rsqrt(var + eps)
        z = xc * rsigma
        ctx.save_for_backward(z, rsigma)
        ctx.mode = mode
        ctx.input_dtype = x.dtype
        return z.to(dtype=x.dtype)

    @staticmethod
    def backward(ctx, g_tilde):
        z, rsigma = ctx.saved_tensors
        v = g_tilde.float() if g_tilde.dtype in (torch.float16, torch.bfloat16) else g_tilde
        v = v.to(dtype=z.dtype)
        v_mean = v.mean(dim=-1, keepdim=True)
        if ctx.mode == "projected":
            vz_mean = (v * z).mean(dim=-1, keepdim=True)
            p_v = v - v_mean - z * vz_mean
        else:  # center_scale
            p_v = v - v_mean
        g_x = p_v * rsigma
        # No gradients for the non-tensor arguments eps and mode.
        return g_x.to(dtype=ctx.input_dtype), None, None


class LayerNormProjectedSurrogate(nn.Module):
    """LN forward = standard normalization. LN backward = projected surrogate (not BP).

    mode='projected': full mean-center + radial removal + 1/σ scaling
    mode='center_scale': mean-center + 1/σ only (no radial removal)
    Affine (γ, β) handled outside the custom Function so g̃ = ∂L/∂z exactly.
    """

    def __init__(self, normalized_shape, eps=1e-5, mode="projected",
                 elementwise_affine=False, bias=True):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps
        self.mode = mode
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(normalized_shape))
            self.bias_param = nn.Parameter(torch.zeros(normalized_shape)) if bias else None
        else:
            self.weight = None
            self.bias_param = None

    def forward(self, x):
        z = _ProjectedSurrogateLNFn.apply(x, self.eps, self.mode)
        if self.weight is not None:
            z = z * self.weight
        if self.bias_param is not None:
            z = z + self.bias_param
        return z


class SoftmaxValueMixLocalFn(torch.autograd.Function):
    """Fused softmax(S) @ V with local backward.

    Forward: A = softmax(S), O = A @ V
    Backward: g_S_{i,j} = A_{ij} * <δO_i, V_j - O_i> (no lateral sum!)
              δV = A^T @ δO (attention-weighted gather)

    The softmax Jacobian's "lateral sum" Σ_j A_ij g_ij collapses to a per-query
    scalar baseline <δO_i, O_i> when composed with A@V — pure algebra, not approximation.
    """

    @staticmethod
    def forward(ctx, scores, v):
        attn = F.softmax(scores, dim=-1)
        out = torch.einsum("bhtk,bhkd->bhtd", attn, v)
        ctx.save_for_backward(attn.detach(), out.detach(), v.detach())
        return out

    @staticmethod
    def backward(ctx, delta_out):
        attn, out, v = ctx.saved_tensors
        # g_A_{i,j} = <δO_i, V_j>
        g_a = torch.einsum("bhtd,bhkd->bhtk", delta_out, v)
        # baseline = <δO_i, O_i> per query (the "lateral sum" collapsed to this)
        baseline = (delta_out * out).sum(dim=-1, keepdim=True)
        # g_S_{i,j} = A_{ij} * (<δO_i, V_j> - <δO_i, O_i>)
        g_scores = attn * (g_a - baseline)
        # δV = A^T @ δO (value gradient)
        delta_v = torch.einsum("bhtk,bhtd->bhkd", attn, delta_out)
        return g_scores, delta_v


@dataclass
class LocalGPTConfig:
    """Hyper-parameters and ablation switches consumed by the LocalGPT modules."""

    block_size: int = 256
    vocab_size: int = 65
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    dropout: float = 0.2
    bias: bool = False
    attn_mode: str = "sigmoid"
    sigmoid_bias_mode: str = "neg_log_n"
    method: str = "bp"  # bp | fa | dfa | sign_sym
    # STE ablation flags
    ste_sigmoid: bool = False  # skip A(1-A) in sigmoid attention backward
    ste_gelu: bool = False  # skip gelu' in FFN backward
    freeze_emb: bool = False  # freeze token + position embeddings
    # LN backward mode: "bp" (standard), "ste" (identity), "center_scale", "projected"
    ln_mode: str = "bp"
    fuse_attn_local: bool = False  # fuse softmax+A@V with local backward (no lateral sum)
    # Sparsity options for SparseFormer experiments
    mlp_topk: int = 0  # if > 0, apply hard top-k (k-WTA) to MLP hidden activation (4*n_embd dim)
    resid_topk: int = 0  # if > 0, apply hard top-k to residual stream output of each block (n_embd dim)
    # FrozenCodeFormer: directional VQ to fixed codebook at residual stream end
    vq_codes: int = 0  # if > 0, apply VQResidualDir with K=vq_codes fixed unit-norm codebook entries
    # FrozenSubspace: continuous r-dim subspace constraint (shared Q across all blocks)
    subspace_rank: int = 0  # if > 0, project residual stream to fixed r-dim subspace at each block
    # FA B-init mode (only used when method='fa'): gaussian | orthogonal | ortho_he | sparse
    fa_init_mode: str = "gaussian"
    fa_sparse_k: int = 0  # for fa_init_mode='sparse': non-zero entries per row (0 = auto in/16)
    # GrAPE: per-step JVP-based cosine alignment of B toward true Jacobian (forward-only, no W^T)
    fa_grape: bool = False
    fa_grape_n_probe: int = 32  # batch size for JVP probes
    # Path IV: learned per-block residual gates. Each block: x + α_attn·attn(x) + α_mlp·mlp(x)
    gated_blocks: bool = False  # if True, add learnable scalar gate per (block, sublayer)


def _make_linear(config, in_features, out_features):
    """Build a LocalLinear using the settings shared by all attention/MLP projections."""
    return LocalLinear(in_features, out_features, bias=config.bias, method=config.method,
                       fa_init_mode=config.fa_init_mode,
                       # A configured 0 is forwarded as None.
                       fa_sparse_k=(config.fa_sparse_k or None),
                       fa_grape=config.fa_grape,
                       fa_grape_n_probe=config.fa_grape_n_probe)


class LocalCausalSelfAttention(nn.Module):
    """Causal multi-head self-attention with separate q/k/v/o LocalLinear projections.

    attn_mode='softmax' normalizes masked scores with softmax; attn_mode='sigmoid'
    applies an element-wise sigmoid to (scores + sig_bias). When fuse_attn_local is
    set together with softmax mode, softmax and the value mix run through
    SoftmaxValueMixLocalFn, and attention dropout is not applied on that path.

    Raises:
        ValueError: if config.attn_mode is neither "sigmoid" nor "softmax".
    """

    def __init__(self, config: LocalGPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Fail fast: forward() would otherwise hit the sigmoid branch without a
        # sig_bias attribute for any unrecognized mode.
        if config.attn_mode not in ("sigmoid", "softmax"):
            raise ValueError(f"Unknown attn_mode: {config.attn_mode}")
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.block_size = config.block_size
        self.attn_mode = config.attn_mode
        self.ste_sigmoid = config.ste_sigmoid
        self.fuse_attn_local = config.fuse_attn_local

        self.q_proj = _make_linear(config, config.n_embd, config.n_embd)
        self.k_proj = _make_linear(config, config.n_embd, config.n_embd)
        self.v_proj = _make_linear(config, config.n_embd, config.n_embd)
        self.o_proj = _make_linear(config, config.n_embd, config.n_embd)
        self.attn_drop = nn.Dropout(config.dropout)
        self.resid_drop = nn.Dropout(config.dropout)

        causal = torch.tril(torch.ones(config.block_size, config.block_size, dtype=torch.bool))
        self.register_buffer("causal_mask", causal, persistent=False)

        if config.attn_mode == "sigmoid":
            # Any bias mode other than "zero" starts at -log(block_size).
            init_b = 0.0 if config.sigmoid_bias_mode == "zero" else -math.log(config.block_size)
            if config.sigmoid_bias_mode == "learned":
                self.sig_bias = nn.Parameter(torch.tensor(init_b))
            else:
                self.register_buffer("sig_bias", torch.tensor(init_b), persistent=False)

    def forward(self, x):
        """Apply causal attention to x of shape (B, T, C) and return the same shape."""
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        scores = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        mask = self.causal_mask[:T, :T]
        scores = scores.masked_fill(~mask, float("-inf"))

        if self.fuse_attn_local and self.attn_mode == "softmax":
            # Fused softmax+A@V with local backward:
            # g_S_{i,j} = A_{ij} * <δO_i, V_j - O_i> (no lateral sum)
            out = SoftmaxValueMixLocalFn.apply(scores, v)
        else:
            if self.attn_mode == "softmax":
                attn = F.softmax(scores, dim=-1)
            elif self.ste_sigmoid:
                attn = SigmoidSTE.apply(scores + self.sig_bias)
            else:
                attn = torch.sigmoid(scores + self.sig_bias)
            attn = self.attn_drop(attn)
            out = attn @ v

        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.o_proj(out))


class LocalMLP(nn.Module):
    """Feed-forward sublayer: fc (n_embd → 4·n_embd), GELU, optional hard top-k, proj, dropout."""

    def __init__(self, config: LocalGPTConfig):
        super().__init__()
        self.fc = _make_linear(config, config.n_embd, 4 * config.n_embd)
        self.proj = _make_linear(config, 4 * config.n_embd, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        self.ste_gelu = config.ste_gelu
        self.mlp_topk = config.mlp_topk

    def forward(self, x):
        h = self.fc(x)
        if self.ste_gelu:
            h = GELUSTE.apply(h)
        else:
            h = F.gelu(h)
        if self.mlp_topk > 0:
            h = HardTopK.apply(h, self.mlp_topk)
        return self.drop(self.proj(h))


def _make_ln(config):
    """Build the right LN variant based on config.ln_mode."""
    if config.ln_mode == "bp":
        return nn.LayerNorm(config.n_embd)
    if config.ln_mode == "ste":
        return LayerNormSTE(config.n_embd)
    if config.ln_mode in ("center_scale", "projected"):
        return LayerNormProjectedSurrogate(
            config.n_embd, mode=config.ln_mode, elementwise_affine=True,
        )
    raise ValueError(f"Unknown ln_mode: {config.ln_mode}")


class LocalBlock(nn.Module):
    """Pre-LN transformer block with optional gates and residual-stream constraints.

    After the attention and MLP residual updates, the enabled constraints are
    applied in this order: hard top-k, directional VQ, frozen-subspace projection.
    """

    def __init__(self, config: LocalGPTConfig):
        super().__init__()
        self.ln1 = _make_ln(config)
        self.ln2 = _make_ln(config)
        self.attn = LocalCausalSelfAttention(config)
        self.mlp = LocalMLP(config)
        self.resid_topk = config.resid_topk
        self.vq = VQResidualDir(config.n_embd, config.vq_codes) if config.vq_codes > 0 else None
        # FrozenSubspace uses fixed seed=42 so all blocks share the same Q (same subspace).
        self.subspace = FrozenSubspace(config.n_embd, config.subspace_rank, seed=42) \
            if config.subspace_rank > 0 else None
        # Path IV: per-sublayer learned residual gates. Init to 1.0 (no initial gating).
        # If a sublayer is "noise net" under BPfree, its α can drive toward 0.
        if config.gated_blocks:
            self.alpha_attn = nn.Parameter(torch.ones(1))
            self.alpha_mlp = nn.Parameter(torch.ones(1))
        else:
            self.alpha_attn = None
            self.alpha_mlp = None

    def forward(self, x):
        if self.alpha_attn is not None:
            x = x + self.alpha_attn * self.attn(self.ln1(x))
            x = x + self.alpha_mlp * self.mlp(self.ln2(x))
        else:
            x = x + self.attn(self.ln1(x))
            x = x + self.mlp(self.ln2(x))
        if self.resid_topk > 0:
            x = HardTopK.apply(x, self.resid_topk)
        if self.vq is not None:
            x = self.vq(x)
        if self.subspace is not None:
            x = self.subspace(x)
        return x


class LocalGPT(nn.Module):
    """GPT-style language model assembled from LocalBlock layers and a LocalLinear head."""

    def __init__(self, config: LocalGPTConfig):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.block_size, config.n_embd)
        if config.freeze_emb:
            self.tok_emb.weight.requires_grad_(False)
            self.pos_emb.weight.requires_grad_(False)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([LocalBlock(config) for _ in range(config.n_layer)])
        self.ln_f = _make_ln(config)
        # Output head: also a LocalLinear (last linear layer before logits)
        self.head = LocalLinear(config.n_embd, config.vocab_size, bias=False, method=config.method,
                                fa_init_mode=config.fa_init_mode,
                                fa_sparse_k=(config.fa_sparse_k or None))

        self.apply(self._init_weights)
        # Scale projection weights to reduce residual stream growth
        for pn, p in self.named_parameters():
            if pn.endswith("o_proj.weight") or pn.endswith("mlp.proj.weight"):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, m):
        """Initialize linear/embedding weights from N(0, 0.02²) and zero any linear bias."""
        if isinstance(m, (nn.Linear, LocalLinear)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def num_params(self) -> int:
        """Return the total element count over all parameters (frozen ones included)."""
        return sum(p.numel() for p in self.parameters())

    def forward(self, idx, targets=None):
        """Compute logits for token ids idx of shape (B, T).

        Returns:
            (logits, None) when targets is None, otherwise (logits, loss) where
            loss is the cross-entropy over all flattened positions.
        """
        B, T = idx.shape
        assert T <= self.config.block_size
        pos = torch.arange(T, device=idx.device)
        x = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.head(x)
        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens: int, temperature: float = 1.0, top_k=None):
        """Sample max_new_tokens tokens autoregressively and return the extended idx.

        Each step feeds at most the last block_size tokens, divides the final
        position's logits by temperature, optionally keeps only the top_k largest
        logits, and samples from the resulting softmax distribution.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Clamp so a top_k above the vocabulary size does not make topk raise.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("inf")
            probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, 1)
            idx = torch.cat([idx, nxt], dim=1)
        return idx
