"""Sigmoid GPT with split Q/K/V projections and LocalLinear for method dispatch.

Derived from model.py but uses LocalLinear for every linear layer and splits
the fused qkv into separate q_proj, k_proj, v_proj so that each projection has
its own feedback matrix for FA / DFA / sign_sym.
"""

import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

from local_layers import LocalLinear


class SigmoidSTE(torch.autograd.Function):
    """Sigmoid forward, straight-through backward (skip A(1-A) derivative)."""

    @staticmethod
    def forward(ctx, x):
        return torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_out):
        # Identity backward: the A(1-A) factor is deliberately dropped.
        return grad_out


class GELUSTE(torch.autograd.Function):
    """GELU forward, straight-through backward (skip gelu' derivative)."""

    @staticmethod
    def forward(ctx, x):
        return F.gelu(x)

    @staticmethod
    def backward(ctx, grad_out):
        # Identity backward: gelu'(x) is deliberately dropped.
        return grad_out


class HardTopK(torch.autograd.Function):
    """k-WTA: zero out all but top-k (by abs value) along last dim, in BOTH forward and backward.

    Forward: keep top-k entries, zero rest.
    Backward: gradient mask = forward mask (only winners get gradient).
    This enforces strict sparsity — non-selected channels never update.
    """

    @staticmethod
    def forward(ctx, x, k):
        # Indices of the k largest-magnitude entries along the last dim.
        _, topk_idx = x.abs().topk(k, dim=-1)
        mask = torch.zeros_like(x).scatter_(-1, topk_idx, 1.0)
        ctx.save_for_backward(mask)
        return x * mask

    @staticmethod
    def backward(ctx, grad_out):
        (mask,) = ctx.saved_tensors
        # Second return is the (non-differentiable) gradient for ``k``.
        return grad_out * mask, None


class FrozenSubspace(nn.Module):
    """Project h to fixed r-dim orthonormal subspace via Q Q^T h.

    Q ∈ R^{d × r} is a random orthonormal basis, frozen at init.
    Output lives in span(Q) ⊂ R^d. (d-r) directions are killed.
    With same seed across blocks, all layers share the same subspace — so the
    residual stream is constrained to span(Q) throughout the network.
    Differentiable (no STE): grad_h = Q Q^T grad_out (same projection).

    For BPfree: residual codebook subspace ≈ span(Q) is exactly what BPfree
    delivers gradient on, so feedback geometry is matched by construction.
    """

    def __init__(self, d_model, rank, seed=42):
        super().__init__()
        self.rank = rank
        gen = torch.Generator()
        gen.manual_seed(seed)
        Q, _ = torch.linalg.qr(torch.randn(d_model, rank, generator=gen))
        self.register_buffer("Q", Q)  # (d, r)

    def forward(self, h):
        # h @ Q → (..., r) coefficients in basis
        # @ Q.t() → back to (..., d), now in span(Q)
        return h @ self.Q @ self.Q.t()


class VQResidualDir(nn.Module):
    """Directional quantization to fixed codebook with STE backward.

    Forward: replace h's direction with nearest of K fixed unit-norm codebook
    entries (per token). Magnitude is preserved.
    Backward: identity through h (STE — no gradient on the codebook lookup itself).

    Codebook is initialized with random unit-norm directions and FROZEN
    (registered buffer). The "feature directions" are predefined — the network
    only learns *which code to land on* per token per layer. Discrete
    bottleneck: log2(K) bits per token per layer.

    If ``seed`` is given, the codebook is drawn from a private generator seeded
    with it (identical across instances sharing the seed). If ``seed`` is None,
    it is drawn from torch's global RNG, so it follows ``torch.manual_seed`` and
    differs between instances.

    For BPfree: the gradient signal needed to switch between codes is in
    {radial, low-rank residual} subspace, matching BPfree exit's bandwidth.
    """

    def __init__(self, d_model, n_codes, seed=None):
        super().__init__()
        self.n_codes = n_codes
        # A freshly constructed torch.Generator always starts from the same
        # fixed default seed, so using one unseeded would give every instance
        # an identical codebook regardless of torch.manual_seed. Only build a
        # private generator when a seed is requested; otherwise use the
        # global RNG (generator=None).
        gen = None
        if seed is not None:
            gen = torch.Generator()
            gen.manual_seed(seed)
        codes = torch.randn(n_codes, d_model, generator=gen)
        codes = codes / codes.norm(dim=-1, keepdim=True).clamp_min(1e-8)
        self.register_buffer("codebook", codes)

    def forward(self, h):
        h_norm = h.norm(dim=-1, keepdim=True).clamp_min(1e-8)
        h_hat = h / h_norm
        sims = h_hat @ self.codebook.t()  # (..., K)
        idx = sims.argmax(dim=-1)
        z_q_dir = self.codebook[idx]  # (..., d) unit-norm direction
        z_q = z_q_dir * h_norm  # restore magnitude
        # STE: forward = z_q, backward = identity through h
        return h + (z_q - h).detach()


class LayerNormSTE(nn.Module):
    """LayerNorm forward, straight-through backward (gradient passes through as identity).

    NOTE: ``self.ln`` is evaluated under ``torch.no_grad()`` and its output is
    detached, so the LayerNorm affine parameters (``ln.weight`` / ``ln.bias``)
    never receive a gradient in this mode and stay at their initial values.
    """

    def __init__(self, normalized_shape):
        super().__init__()
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, x):
        with torch.no_grad():
            out = self.ln(x)
        return x + (out - x).detach()


class _ProjectedSurrogateLNFn(torch.autograd.Function):
    """Core autograd function for projected surrogate LN backward.

    mode='projected': full P_z(v) = v - mean(v) - z*mean(v*z), scaled by 1/σ
    mode='center_scale': only v - mean(v), scaled by 1/σ (no radial removal)

    NOTE: with σ = sqrt(var + eps) as computed in ``forward``, the 'projected'
    expression is algebraically the exact input gradient of (non-affine)
    LayerNorm; only 'center_scale' departs from backprop.
    """

    @staticmethod
    def forward(ctx, x, eps, mode):
        # Do the statistics in fp32 for half-precision inputs.
        x_f = x.float() if x.dtype in (torch.float16, torch.bfloat16) else x
        mu = x_f.mean(dim=-1, keepdim=True)
        xc = x_f - mu
        var = (xc * xc).mean(dim=-1, keepdim=True)
        rsigma = torch.rsqrt(var + eps)
        z = xc * rsigma
        ctx.save_for_backward(z, rsigma)
        ctx.mode = mode
        ctx.input_dtype = x.dtype
        return z.to(dtype=x.dtype)

    @staticmethod
    def backward(ctx, g_tilde):
        z, rsigma = ctx.saved_tensors
        v = g_tilde.float() if g_tilde.dtype in (torch.float16, torch.bfloat16) else g_tilde
        v = v.to(dtype=z.dtype)
        v_mean = v.mean(dim=-1, keepdim=True)
        if ctx.mode == "projected":
            vz_mean = (v * z).mean(dim=-1, keepdim=True)
            p_v = v - v_mean - z * vz_mean
        else:  # center_scale
            p_v = v - v_mean
        g_x = p_v * rsigma
        # eps and mode are non-tensor arguments: no gradient.
        return g_x.to(dtype=ctx.input_dtype), None, None


class LayerNormProjectedSurrogate(nn.Module):
    """LN forward = standard normalization. LN backward = projected surrogate (not BP).

    mode='projected': full mean-center + radial removal + 1/σ scaling
    mode='center_scale': mean-center + 1/σ only (no radial removal)

    Affine (γ, β) handled outside the custom Function so g̃ = ∂L/∂z exactly.
    """

    def __init__(self, normalized_shape, eps=1e-5, mode="projected",
                 elementwise_affine=False, bias=True):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps
        self.mode = mode
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(normalized_shape))
            self.bias_param = nn.Parameter(torch.zeros(normalized_shape)) if bias else None
        else:
            self.weight = None
            self.bias_param = None

    def forward(self, x):
        z = _ProjectedSurrogateLNFn.apply(x, self.eps, self.mode)
        if self.weight is not None:
            z = z * self.weight
        if self.bias_param is not None:
            z = z + self.bias_param
        return z


class SoftmaxValueMixLocalFn(torch.autograd.Function):
    """Fused softmax(S) @ V with local backward.

    Forward:  A = softmax(S), O = A @ V
    Backward: g_S_{i,j} = A_{ij} * <δO_i, V_j - O_i>   (no lateral sum!)
              δV = A^T @ δO                           (attention-weighted gather)

    The softmax Jacobian's "lateral sum" Σ_j A_ij g_ij collapses to a per-query
    scalar baseline <δO_i, O_i> when composed with A@V — pure algebra, not
    approximation.
    """

    @staticmethod
    def forward(ctx, scores, v):
        attn = F.softmax(scores, dim=-1)
        out = torch.einsum("bhtk,bhkd->bhtd", attn, v)
        ctx.save_for_backward(attn.detach(), out.detach(), v.detach())
        return out

    @staticmethod
    def backward(ctx, delta_out):
        attn, out, v = ctx.saved_tensors
        # g_A_{i,j} = <δO_i, V_j>
        g_a = torch.einsum("bhtd,bhkd->bhtk", delta_out, v)
        # baseline = <δO_i, O_i> per query (the "lateral sum" collapsed to this)
        baseline = (delta_out * out).sum(dim=-1, keepdim=True)
        # g_S_{i,j} = A_{ij} * (<δO_i, V_j> - <δO_i, O_i>)
        g_scores = attn * (g_a - baseline)
        # δV = A^T @ δO (value gradient)
        delta_v = torch.einsum("bhtk,bhtd->bhkd", attn, delta_out)
        return g_scores, delta_v


@dataclass
class LocalGPTConfig:
    block_size: int = 256
    vocab_size: int = 65
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    dropout: float = 0.2
    bias: bool = False
    attn_mode: str = "sigmoid"  # sigmoid | softmax
    sigmoid_bias_mode: str = "neg_log_n"
    method: str = "bp"  # bp | fa | dfa | sign_sym
    # STE ablation flags
    ste_sigmoid: bool = False  # skip A(1-A) in sigmoid attention backward
    ste_gelu: bool = False  # skip gelu' in FFN backward
    freeze_emb: bool = False  # freeze token + position embeddings
    # LN backward mode: "bp" (standard), "ste" (identity), "center_scale", "projected"
    ln_mode: str = "bp"
    fuse_attn_local: bool = False  # fuse softmax+A@V with local backward (no lateral sum)
    # Sparsity options for SparseFormer experiments
    mlp_topk: int = 0  # if > 0, apply hard top-k (k-WTA) to MLP hidden activation (4*n_embd dim)
    resid_topk: int = 0  # if > 0, apply hard top-k to residual stream output of each block (n_embd dim)
    # FrozenCodeFormer: directional VQ to fixed codebook at residual stream end
    vq_codes: int = 0
    # if > 0, apply VQResidualDir with K=vq_codes fixed unit-norm codebook entries
    # FrozenSubspace: continuous r-dim subspace constraint (shared Q across all blocks)
    subspace_rank: int = 0  # if > 0, project residual stream to fixed r-dim subspace at each block
    # FA B-init mode (only used when method='fa'): gaussian | orthogonal | ortho_he | sparse
    fa_init_mode: str = "gaussian"
    fa_sparse_k: int = 0  # for fa_init_mode='sparse': non-zero entries per row (0 = auto in/16)
    # GrAPE: per-step JVP-based cosine alignment of B toward true Jacobian (forward-only, no W^T)
    fa_grape: bool = False
    fa_grape_n_probe: int = 32  # batch size for JVP probes
    # Path IV: learned per-block residual gates. Each block: x + α_attn·attn(x) + α_mlp·mlp(x)
    gated_blocks: bool = False  # if True, add learnable scalar gate per (block, sublayer)


def _make_local_linear(config, in_features, out_features, bias, *, with_grape=True):
    """Build a LocalLinear wired to the method / FA settings in ``config``.

    ``fa_sparse_k == 0`` is forwarded as None. When ``with_grape`` is False the
    ``fa_grape`` / ``fa_grape_n_probe`` kwargs are omitted entirely, so
    LocalLinear's own defaults for them apply.
    """
    kwargs = dict(
        bias=bias,
        method=config.method,
        fa_init_mode=config.fa_init_mode,
        fa_sparse_k=(config.fa_sparse_k or None),
    )
    if with_grape:
        kwargs.update(fa_grape=config.fa_grape, fa_grape_n_probe=config.fa_grape_n_probe)
    return LocalLinear(in_features, out_features, **kwargs)


class LocalCausalSelfAttention(nn.Module):
    """Causal multi-head attention (sigmoid or softmax) with separate Q/K/V/O LocalLinears."""

    def __init__(self, config: LocalGPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Only "sigmoid" and "softmax" are handled; any other value would
        # otherwise fail later in forward() because sig_bias is never created.
        if config.attn_mode not in ("sigmoid", "softmax"):
            raise ValueError(f"Unknown attn_mode: {config.attn_mode}")
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.block_size = config.block_size
        self.attn_mode = config.attn_mode
        self.ste_sigmoid = config.ste_sigmoid
        self.fuse_attn_local = config.fuse_attn_local
        self.q_proj = _make_local_linear(config, config.n_embd, config.n_embd, config.bias)
        self.k_proj = _make_local_linear(config, config.n_embd, config.n_embd, config.bias)
        self.v_proj = _make_local_linear(config, config.n_embd, config.n_embd, config.bias)
        self.o_proj = _make_local_linear(config, config.n_embd, config.n_embd, config.bias)
        self.attn_drop = nn.Dropout(config.dropout)
        self.resid_drop = nn.Dropout(config.dropout)
        causal = torch.tril(torch.ones(config.block_size, config.block_size, dtype=torch.bool))
        self.register_buffer("causal_mask", causal, persistent=False)
        if config.attn_mode == "sigmoid":
            # "zero" -> 0; any other mode (incl. "learned") starts at -log(block_size).
            init_b = 0.0 if config.sigmoid_bias_mode == "zero" else -math.log(config.block_size)
            if config.sigmoid_bias_mode == "learned":
                self.sig_bias = nn.Parameter(torch.tensor(init_b))
            else:
                self.register_buffer("sig_bias", torch.tensor(init_b), persistent=False)

    def forward(self, x):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        scores = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        mask = self.causal_mask[:T, :T]
        scores = scores.masked_fill(~mask, float("-inf"))
        if self.fuse_attn_local and self.attn_mode == "softmax":
            # Fused softmax+A@V with local backward:
            # g_S_{i,j} = A_{ij} * <δO_i, V_j - O_i> (no lateral sum)
            # NOTE: this path does not apply attn_drop.
            out = SoftmaxValueMixLocalFn.apply(scores, v)
        else:
            if self.attn_mode == "softmax":
                attn = F.softmax(scores, dim=-1)
            elif self.ste_sigmoid:
                attn = SigmoidSTE.apply(scores + self.sig_bias)
            else:
                attn = torch.sigmoid(scores + self.sig_bias)
            attn = self.attn_drop(attn)
            out = attn @ v
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.o_proj(out))


class LocalMLP(nn.Module):
    """fc -> GELU (optionally STE) -> optional k-WTA -> proj -> dropout."""

    def __init__(self, config: LocalGPTConfig):
        super().__init__()
        self.fc = _make_local_linear(config, config.n_embd, 4 * config.n_embd, config.bias)
        self.proj = _make_local_linear(config, 4 * config.n_embd, config.n_embd, config.bias)
        self.drop = nn.Dropout(config.dropout)
        self.ste_gelu = config.ste_gelu
        self.mlp_topk = config.mlp_topk

    def forward(self, x):
        h = self.fc(x)
        if self.ste_gelu:
            h = GELUSTE.apply(h)
        else:
            h = F.gelu(h)
        if self.mlp_topk > 0:
            h = HardTopK.apply(h, self.mlp_topk)
        return self.drop(self.proj(h))


def _make_ln(config):
    """Build the right LN variant based on config.ln_mode."""
    if config.ln_mode == "bp":
        return nn.LayerNorm(config.n_embd)
    if config.ln_mode == "ste":
        return LayerNormSTE(config.n_embd)
    if config.ln_mode in ("center_scale", "projected"):
        return LayerNormProjectedSurrogate(
            config.n_embd, mode=config.ln_mode, elementwise_affine=True,
        )
    raise ValueError(f"Unknown ln_mode: {config.ln_mode}")


class LocalBlock(nn.Module):
    """Pre-LN transformer block with optional gates, top-k, VQ and subspace projection."""

    def __init__(self, config: LocalGPTConfig):
        super().__init__()
        self.ln1 = _make_ln(config)
        self.ln2 = _make_ln(config)
        self.attn = LocalCausalSelfAttention(config)
        self.mlp = LocalMLP(config)
        self.resid_topk = config.resid_topk
        self.vq = VQResidualDir(config.n_embd, config.vq_codes) if config.vq_codes > 0 else None
        # FrozenSubspace uses fixed seed=42 so all blocks share the same Q (same subspace).
        self.subspace = (
            FrozenSubspace(config.n_embd, config.subspace_rank, seed=42)
            if config.subspace_rank > 0 else None
        )
        # Path IV: per-sublayer learned residual gates. Init to 1.0 (no initial gating).
        # If a sublayer is "noise net" under BPfree, its α can drive toward 0.
        if config.gated_blocks:
            self.alpha_attn = nn.Parameter(torch.ones(1))
            self.alpha_mlp = nn.Parameter(torch.ones(1))
        else:
            self.alpha_attn = None
            self.alpha_mlp = None

    def forward(self, x):
        if self.alpha_attn is not None:
            x = x + self.alpha_attn * self.attn(self.ln1(x))
            x = x + self.alpha_mlp * self.mlp(self.ln2(x))
        else:
            x = x + self.attn(self.ln1(x))
            x = x + self.mlp(self.ln2(x))
        # Post-block residual-stream constraints, applied in this fixed order.
        if self.resid_topk > 0:
            x = HardTopK.apply(x, self.resid_topk)
        if self.vq is not None:
            x = self.vq(x)
        if self.subspace is not None:
            x = self.subspace(x)
        return x


class LocalGPT(nn.Module):
    """GPT language model whose linear layers are all LocalLinear."""

    def __init__(self, config: LocalGPTConfig):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.block_size, config.n_embd)
        if config.freeze_emb:
            self.tok_emb.weight.requires_grad_(False)
            self.pos_emb.weight.requires_grad_(False)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([LocalBlock(config) for _ in range(config.n_layer)])
        self.ln_f = _make_ln(config)
        # Output head: also a LocalLinear (last linear layer before logits).
        # NOTE(review): unlike every other LocalLinear here, the head is built
        # without fa_grape / fa_grape_n_probe (LocalLinear defaults apply).
        # Unclear whether intentional — TODO confirm.
        self.head = _make_local_linear(
            config, config.n_embd, config.vocab_size, False, with_grape=False,
        )
        self.apply(self._init_weights)
        # Scale projection weights to reduce residual stream growth
        for pn, p in self.named_parameters():
            if pn.endswith("o_proj.weight") or pn.endswith("mlp.proj.weight"):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, LocalLinear)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def num_params(self) -> int:
        """Total number of parameters (including frozen ones)."""
        return sum(p.numel() for p in self.parameters())

    def forward(self, idx, targets=None):
        """Return (logits, loss); loss is None when ``targets`` is None."""
        B, T = idx.shape
        assert T <= self.config.block_size, (
            f"Cannot forward sequence of length {T}; block_size is {self.config.block_size}"
        )
        pos = torch.arange(T, device=idx.device)
        x = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.head(x)
        if targets is None:
            return logits, None
        # reshape (not view) so non-contiguous targets/logits are accepted.
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens: int, temperature: float = 1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens appended to ``idx``."""
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Clamp so a top_k larger than the vocabulary keeps every
                # logit instead of making torch.topk raise.
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k)
                logits[logits < v[:, [-1]]] = -float("inf")
            probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, 1)
            idx = torch.cat([idx, nxt], dim=1)
        return idx