summaryrefslogtreecommitdiff
path: root/ep_run/model_local.py
diff options
context:
space:
mode:
author: Yuren Hao <yurenh2@illinois.edu>, 2026-07-03 05:56:50 -0500
committer: Yuren Hao <yurenh2@illinois.edu>, 2026-07-03 05:56:50 -0500
commit: b83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
tree: b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/model_local.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/model_local.py')
-rw-r--r-- ep_run/model_local.py 470
1 files changed, 470 insertions, 0 deletions
diff --git a/ep_run/model_local.py b/ep_run/model_local.py
new file mode 100644
index 0000000..a84c692
--- /dev/null
+++ b/ep_run/model_local.py
@@ -0,0 +1,470 @@
+"""Sigmoid GPT with split Q/K/V projections and LocalLinear for method dispatch.
+
+Derived from model.py but uses LocalLinear for every linear layer and splits
+the fused qkv into separate q_proj, k_proj, v_proj so that each projection has
+its own feedback matrix for FA / DFA / sign_sym.
+"""
+import math
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from local_layers import LocalLinear
+
+
class SigmoidSTE(torch.autograd.Function):
    """Element-wise sigmoid with a straight-through gradient.

    Forward is an ordinary sigmoid. Backward returns the upstream gradient
    unchanged, i.e. the A(1-A) factor of the true derivative is dropped.
    """

    @staticmethod
    def forward(ctx, x):
        # Nothing is saved on ctx: backward does not need forward values.
        return x.sigmoid()

    @staticmethod
    def backward(ctx, upstream):
        # Identity Jacobian (straight-through estimator).
        return upstream
+
+
class GELUSTE(torch.autograd.Function):
    """GELU activation with a straight-through gradient.

    Forward applies ``F.gelu``. Backward returns the upstream gradient
    unchanged, so the gelu' factor never enters the backward pass.
    """

    @staticmethod
    def forward(ctx, x):
        # Nothing is saved on ctx: backward does not need forward values.
        activated = F.gelu(x)
        return activated

    @staticmethod
    def backward(ctx, upstream):
        # Identity Jacobian (straight-through estimator).
        return upstream
+
+
class HardTopK(torch.autograd.Function):
    """k-WTA: zero out all but top-k (by abs value) along last dim, in BOTH forward and backward.

    Forward: keep the k entries with the largest absolute value, zero the rest.
    Backward: gradient mask = forward mask (only winners get gradient).

    This enforces strict sparsity — non-selected channels never update.

    Args (to ``apply``):
        x: input tensor; selection runs independently over the last dimension.
        k: number of entries to keep. Values larger than the last-dimension
           size are clamped, so an oversize ``k`` keeps every entry instead of
           making ``topk`` raise.

    Returns:
        Tensor shaped like ``x`` with non-winning entries set to zero. ``k``
        receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, k):
        # Clamp so that k >= channel count degrades to a pass-through.
        k = min(int(k), x.size(-1))
        winners = x.abs().topk(k, dim=-1).indices
        # 1.0 at winner positions, 0.0 elsewhere; reused as-is in backward.
        mask = torch.zeros_like(x).scatter_(-1, winners, 1.0)
        ctx.save_for_backward(mask)
        return x * mask

    @staticmethod
    def backward(ctx, grad_out):
        (mask,) = ctx.saved_tensors
        # Second return slot corresponds to the non-differentiable k.
        return grad_out * mask, None
+
+
class FrozenSubspace(nn.Module):
    """Constrain activations to a fixed subspace by applying ``Q Q^T``.

    ``Q`` is the orthonormal factor from the QR decomposition of a seeded
    Gaussian ``(d_model, rank)`` matrix. It is stored as a buffer, so it is
    not a trainable parameter. Two instances built with the same
    ``(d_model, rank, seed)`` hold the same ``Q``, which lets several layers
    share one subspace.

    The projection is an ordinary differentiable matmul chain (no STE), so
    autograd applies the same projection to the gradient.

    Args:
        d_model: size of the last dimension of the inputs.
        rank: number of columns of ``Q`` (dimension of the kept subspace).
        seed: seed for the private generator used to draw the basis.
    """

    def __init__(self, d_model, rank, seed=42):
        super().__init__()
        self.rank = rank
        # Private generator: drawing the basis does not touch the global RNG.
        rng = torch.Generator()
        rng.manual_seed(seed)
        gaussian = torch.randn(d_model, rank, generator=rng)
        basis = torch.linalg.qr(gaussian)[0]
        self.register_buffer("Q", basis)  # (d, r)

    def forward(self, h):
        """Return ``h`` projected onto span(Q); output has the shape of ``h``."""
        # (..., d) -> (..., r): coefficients of h in the basis Q.
        coeffs = h @ self.Q
        # (..., r) -> (..., d): back to model space, now inside span(Q).
        return coeffs @ self.Q.t()
+
+
class VQResidualDir(nn.Module):
    """Snap each vector's direction to the closest entry of a frozen codebook.

    Forward: every vector along the last dimension is replaced by the
    codebook entry with the highest dot product against its unit direction,
    rescaled to the vector's original norm.
    Backward: straight-through — the gradient w.r.t. ``h`` is the identity,
    and nothing flows into the lookup.

    The codebook holds ``n_codes`` unit-norm rows drawn from a Gaussian and is
    registered as a buffer, so it is never trained.

    Args:
        d_model: size of the last dimension of the inputs.
        n_codes: number of codebook rows (K).
        seed: optional seed for the private generator that draws the codebook.
            NOTE(review): with ``seed=None`` the generator is never explicitly
            seeded — confirm whether separate instances then end up with
            distinct codebooks or with the same one.
    """

    def __init__(self, d_model, n_codes, seed=None):
        super().__init__()
        self.n_codes = n_codes
        rng = torch.Generator()
        if seed is not None:
            rng.manual_seed(seed)
        raw = torch.randn(n_codes, d_model, generator=rng)
        # Normalise rows; the clamp guards against dividing by a zero norm.
        lengths = raw.norm(dim=-1, keepdim=True).clamp_min(1e-8)
        self.register_buffer("codebook", raw / lengths)

    def forward(self, h):
        """Return the direction-quantized version of ``h`` (same shape)."""
        magnitude = h.norm(dim=-1, keepdim=True).clamp_min(1e-8)
        direction = h / magnitude
        # (..., K) similarities -> index of the best-matching code per vector.
        nearest = (direction @ self.codebook.t()).argmax(dim=-1)
        # Unit-norm code scaled back to the input's magnitude.
        quantized = self.codebook[nearest] * magnitude
        # STE: the value equals `quantized`, the gradient is identity through h.
        return h + (quantized - h).detach()
+
+
class LayerNormSTE(nn.Module):
    """LayerNorm in the forward pass, identity in the backward pass.

    The wrapped ``nn.LayerNorm`` runs under ``torch.no_grad()``, so the
    gradient w.r.t. the input is the identity and the LayerNorm's own affine
    parameters receive no gradient through this module.
    """

    def __init__(self, normalized_shape):
        super().__init__()
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, x):
        # Correction that turns x into LN(x); built without an autograd graph.
        with torch.no_grad():
            correction = self.ln(x) - x
        # Value: LN(x). Gradient: d(out)/dx = I, since the correction is constant.
        return x + correction
+
+
class _ProjectedSurrogateLNFn(torch.autograd.Function):
    """Affine-free layer normalisation with a hand-written backward.

    Forward: z = (x - mean(x)) / sqrt(var(x) + eps) over the last dimension,
    computed in float32 when the input is float16/bfloat16 and cast back.

    Backward, for incoming gradient v:
        mode='projected':  (v - mean(v) - z * mean(v * z)) / sigma
        any other mode (used as 'center_scale'):  (v - mean(v)) / sigma

    NOTE(review): the 'projected' expression looks identical to the analytic
    LayerNorm input gradient — confirm that is the intended "surrogate".
    """

    @staticmethod
    def forward(ctx, x, eps, mode):
        low_precision = x.dtype in (torch.float16, torch.bfloat16)
        work = x.float() if low_precision else x
        centered = work - work.mean(dim=-1, keepdim=True)
        # Biased (population) variance over the last dimension.
        variance = (centered * centered).mean(dim=-1, keepdim=True)
        inv_std = torch.rsqrt(variance + eps)
        z = centered * inv_std
        ctx.save_for_backward(z, inv_std)
        ctx.mode = mode
        ctx.input_dtype = x.dtype
        return z.to(dtype=x.dtype)

    @staticmethod
    def backward(ctx, g_tilde):
        z, inv_std = ctx.saved_tensors
        grad = g_tilde
        if grad.dtype in (torch.float16, torch.bfloat16):
            grad = grad.float()
        # Match the dtype the forward statistics were saved in.
        grad = grad.to(dtype=z.dtype)
        projected = grad - grad.mean(dim=-1, keepdim=True)
        if ctx.mode == "projected":
            # Also remove the component of the gradient along z.
            along_z = (grad * z).mean(dim=-1, keepdim=True)
            projected = projected - z * along_z
        grad_x = projected * inv_std
        # eps and mode are non-tensor inputs and get no gradient.
        return grad_x.to(dtype=ctx.input_dtype), None, None
+
+
class LayerNormProjectedSurrogate(nn.Module):
    """LayerNorm module whose normalisation step uses a custom backward.

    The normalisation itself is delegated to ``_ProjectedSurrogateLNFn``;
    ``mode`` is forwarded to it unchanged ('projected' or 'center_scale').
    The optional affine transform (scale, shift) is applied afterwards with
    ordinary autograd ops, so the custom function sees the gradient w.r.t.
    the normalised activations.

    Args:
        normalized_shape: shape of the optional scale/shift parameters.
        eps: constant added to the variance inside the normalisation.
        mode: backward variant, passed through to the autograd function.
        elementwise_affine: create a learnable scale (init 1) when True.
        bias: with ``elementwise_affine``, also create a learnable shift (init 0).
    """

    def __init__(self, normalized_shape, eps=1e-5, mode="projected",
                 elementwise_affine=False, bias=True):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps
        self.mode = mode
        self.weight = (
            nn.Parameter(torch.ones(normalized_shape)) if elementwise_affine else None
        )
        # The shift only exists when the affine transform is enabled at all.
        self.bias_param = (
            nn.Parameter(torch.zeros(normalized_shape))
            if (elementwise_affine and bias)
            else None
        )

    def forward(self, x):
        """Normalise ``x`` and apply whichever affine parameters exist."""
        out = _ProjectedSurrogateLNFn.apply(x, self.eps, self.mode)
        scale, shift = self.weight, self.bias_param
        if scale is not None:
            out = out * scale
        if shift is not None:
            out = out + shift
        return out
+
+
class SoftmaxValueMixLocalFn(torch.autograd.Function):
    """softmax(S) @ V as a single autograd node with an explicit backward.

    Forward:  A = softmax(S) over the last dim,  O = A @ V
    Backward: dS_{ij} = A_{ij} * (<dO_i, V_j> - <dO_i, O_i>)
              dV      = A^T @ dO

    The softmax Jacobian's row sum  sum_j A_ij <dO_i, V_j>  equals
    <dO_i, O_i> because O_i = sum_j A_ij V_j, so it is read off the saved
    output as a per-query scalar instead of being summed over keys.

    The einsum signatures require 4-D operands: ``scores`` as (b, h, t, k)
    and ``v`` as (b, h, k, d).
    """

    @staticmethod
    def forward(ctx, scores, v):
        probs = F.softmax(scores, dim=-1)
        mixed = torch.einsum("bhtk,bhkd->bhtd", probs, v)
        ctx.save_for_backward(probs.detach(), mixed.detach(), v.detach())
        return mixed

    @staticmethod
    def backward(ctx, delta_out):
        probs, mixed, v = ctx.saved_tensors
        # Value gradient: gather dO over queries, weighted by attention.
        grad_v = torch.einsum("bhtk,bhtd->bhkd", probs, delta_out)
        # <dO_i, V_j> for every (query, key) pair.
        pairwise = torch.einsum("bhtd,bhkd->bhtk", delta_out, v)
        # <dO_i, O_i>: one scalar baseline per query.
        per_query = (delta_out * mixed).sum(dim=-1, keepdim=True)
        grad_scores = probs * (pairwise - per_query)
        return grad_scores, grad_v
+
+
@dataclass
class LocalGPTConfig:
    """Hyper-parameters and ablation switches for ``LocalGPT``."""

    # --- Architecture ---
    block_size: int = 256  # maximum sequence length (also sizes the causal mask)
    vocab_size: int = 65
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    dropout: float = 0.2
    bias: bool = False  # bias term in the LocalLinear layers

    # --- Attention ---
    attn_mode: str = "sigmoid"  # "sigmoid" or "softmax"
    sigmoid_bias_mode: str = "neg_log_n"  # "zero", "learned", else fixed -log(block_size)

    # --- Credit-assignment method handed to every LocalLinear ---
    method: str = "bp"  # bp | fa | dfa | sign_sym

    # --- Straight-through ablations ---
    ste_sigmoid: bool = False  # drop A(1-A) in the sigmoid-attention backward
    ste_gelu: bool = False  # drop gelu' in the FFN backward
    freeze_emb: bool = False  # freeze token + position embeddings

    # LayerNorm backward variant: "bp" (standard), "ste" (identity),
    # "center_scale", or "projected".
    ln_mode: str = "bp"
    # Use the fused softmax + A@V autograd node (softmax attention only).
    fuse_attn_local: bool = False

    # --- Sparsity (SparseFormer experiments); 0 disables ---
    mlp_topk: int = 0  # hard top-k (k-WTA) on the MLP hidden activation (4*n_embd wide)
    resid_topk: int = 0  # hard top-k on each block's residual-stream output (n_embd wide)

    # --- FrozenCodeFormer: directional VQ at the end of each block; 0 disables ---
    vq_codes: int = 0  # K fixed unit-norm codebook entries for VQResidualDir

    # --- FrozenSubspace: r-dim subspace constraint with one Q shared by all blocks; 0 disables ---
    subspace_rank: int = 0

    # --- Feedback-alignment options ---
    # Init of the feedback matrix B (only used when method='fa'):
    # gaussian | orthogonal | ortho_he | sparse
    fa_init_mode: str = "gaussian"
    fa_sparse_k: int = 0  # for fa_init_mode='sparse': non-zeros per row (0 = auto in/16)
    # GrAPE: per-step JVP-based cosine alignment of B toward the true Jacobian
    # (forward-only, no W^T).
    fa_grape: bool = False
    fa_grape_n_probe: int = 32  # batch size for the JVP probes

    # --- Path IV: learned residual gates, x + a_attn*attn(x) + a_mlp*mlp(x) ---
    gated_blocks: bool = False  # one learnable scalar gate per (block, sublayer)
+
+
class LocalCausalSelfAttention(nn.Module):
    """Causal multi-head self-attention with separate q/k/v/o ``LocalLinear`` layers.

    Supports two attention non-linearities, chosen by ``config.attn_mode``:

    * ``"softmax"``: standard softmax over keys, optionally through the fused
      ``SoftmaxValueMixLocalFn`` node when ``config.fuse_attn_local`` is set.
    * ``"sigmoid"``: element-wise sigmoid of ``scores + sig_bias``, optionally
      with the straight-through backward (``config.ste_sigmoid``).

    Raises:
        ValueError: if ``config.attn_mode`` is neither "sigmoid" nor "softmax".
    """

    def __init__(self, config: LocalGPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Fail at construction time: any other value would skip creating
        # `sig_bias` and only surface later as an AttributeError in forward().
        if config.attn_mode not in ("sigmoid", "softmax"):
            raise ValueError(f"Unknown attn_mode: {config.attn_mode}")
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.block_size = config.block_size
        self.attn_mode = config.attn_mode
        self.ste_sigmoid = config.ste_sigmoid
        self.fuse_attn_local = config.fuse_attn_local

        # Identical settings for all four projections; each layer is still
        # constructed separately, in the order q, k, v, o.
        linear_kwargs = dict(
            bias=config.bias,
            method=config.method,
            fa_init_mode=config.fa_init_mode,
            fa_sparse_k=(config.fa_sparse_k or None),
            fa_grape=config.fa_grape,
            fa_grape_n_probe=config.fa_grape_n_probe,
        )
        self.q_proj = LocalLinear(config.n_embd, config.n_embd, **linear_kwargs)
        self.k_proj = LocalLinear(config.n_embd, config.n_embd, **linear_kwargs)
        self.v_proj = LocalLinear(config.n_embd, config.n_embd, **linear_kwargs)
        self.o_proj = LocalLinear(config.n_embd, config.n_embd, **linear_kwargs)
        self.attn_drop = nn.Dropout(config.dropout)
        self.resid_drop = nn.Dropout(config.dropout)

        # Lower-triangular boolean mask; non-persistent, so not in state_dict.
        causal = torch.tril(torch.ones(config.block_size, config.block_size, dtype=torch.bool))
        self.register_buffer("causal_mask", causal, persistent=False)

        if config.attn_mode == "sigmoid":
            # "zero" -> 0.0; every other mode starts at -log(block_size).
            init_b = 0.0 if config.sigmoid_bias_mode == "zero" else -math.log(config.block_size)
            if config.sigmoid_bias_mode == "learned":
                self.sig_bias = nn.Parameter(torch.tensor(init_b))
            else:
                self.register_buffer("sig_bias", torch.tensor(init_b), persistent=False)

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (B, T, C); returns (B, T, C)."""
        B, T, C = x.shape
        # (B, T, C) -> (B, n_head, T, head_dim)
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        scores = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        mask = self.causal_mask[:T, :T]
        # Future positions get -inf, which maps to weight 0 under both
        # softmax and sigmoid.
        scores = scores.masked_fill(~mask, float("-inf"))

        if self.attn_mode == "softmax":
            if self.fuse_attn_local:
                # Fused softmax + A@V with its explicit backward. Note that
                # this path does not apply attn_drop.
                out = SoftmaxValueMixLocalFn.apply(scores, v)
            else:
                attn = self.attn_drop(F.softmax(scores, dim=-1))
                out = attn @ v
        else:
            gate_in = scores + self.sig_bias
            attn = SigmoidSTE.apply(gate_in) if self.ste_sigmoid else torch.sigmoid(gate_in)
            out = self.attn_drop(attn) @ v

        # (B, n_head, T, head_dim) -> (B, T, C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.o_proj(out))
+
+
class LocalMLP(nn.Module):
    """Position-wise feed-forward block: fc (n_embd -> 4*n_embd), GELU, proj back.

    Optional variations taken from the config:
      * ``ste_gelu``: use the straight-through GELU backward.
      * ``mlp_topk > 0``: keep only the top-k hidden activations (by absolute
        value) per position before the output projection.
    """

    def __init__(self, config: LocalGPTConfig):
        super().__init__()
        hidden = 4 * config.n_embd
        # Same LocalLinear settings for both layers.
        linear_kwargs = dict(
            bias=config.bias,
            method=config.method,
            fa_init_mode=config.fa_init_mode,
            fa_sparse_k=(config.fa_sparse_k or None),
            fa_grape=config.fa_grape,
            fa_grape_n_probe=config.fa_grape_n_probe,
        )
        self.fc = LocalLinear(config.n_embd, hidden, **linear_kwargs)
        self.proj = LocalLinear(hidden, config.n_embd, **linear_kwargs)
        self.drop = nn.Dropout(config.dropout)
        self.ste_gelu = config.ste_gelu
        self.mlp_topk = config.mlp_topk

    def forward(self, x):
        """Return the feed-forward output for ``x`` (same trailing size n_embd)."""
        pre = self.fc(x)
        hidden = GELUSTE.apply(pre) if self.ste_gelu else F.gelu(pre)
        if self.mlp_topk > 0:
            # k-WTA on the hidden activation, masked in forward and backward.
            hidden = HardTopK.apply(hidden, self.mlp_topk)
        projected = self.proj(hidden)
        return self.drop(projected)
+
+
def _make_ln(config):
    """Return a fresh LayerNorm-like module of width ``config.n_embd``.

    ``config.ln_mode`` selects the variant:
      * "bp": plain ``nn.LayerNorm``
      * "ste": ``LayerNormSTE``
      * "center_scale" / "projected": ``LayerNormProjectedSurrogate`` in that
        mode, with the element-wise affine transform enabled

    Raises:
        ValueError: for any other ``ln_mode``.
    """
    mode = config.ln_mode
    width = config.n_embd
    if mode in ("center_scale", "projected"):
        return LayerNormProjectedSurrogate(width, mode=mode, elementwise_affine=True)
    if mode == "ste":
        return LayerNormSTE(width)
    if mode == "bp":
        return nn.LayerNorm(width)
    raise ValueError(f"Unknown ln_mode: {config.ln_mode}")
+
+
class LocalBlock(nn.Module):
    """Pre-norm transformer block with optional gates and residual constraints.

    Computes ``x + attn(ln1(x))`` followed by ``x + mlp(ln2(x))``. With
    ``config.gated_blocks`` each sublayer output is first multiplied by its
    own learnable scalar (initialised to 1.0). After the two sublayers, the
    enabled constraints are applied to the residual stream in this order:
    hard top-k (``resid_topk``), directional VQ (``vq_codes``), frozen
    subspace projection (``subspace_rank``).
    """

    def __init__(self, config: LocalGPTConfig):
        super().__init__()
        self.ln1 = _make_ln(config)
        self.ln2 = _make_ln(config)
        self.attn = LocalCausalSelfAttention(config)
        self.mlp = LocalMLP(config)
        self.resid_topk = config.resid_topk

        if config.vq_codes > 0:
            self.vq = VQResidualDir(config.n_embd, config.vq_codes)
        else:
            self.vq = None

        # Fixed seed=42 so that every block builds the same Q (shared subspace).
        if config.subspace_rank > 0:
            self.subspace = FrozenSubspace(config.n_embd, config.subspace_rank, seed=42)
        else:
            self.subspace = None

        # Path IV: one learnable scalar per sublayer, starting at 1.0 so the
        # block initially behaves like the ungated one.
        gated = config.gated_blocks
        self.alpha_attn = nn.Parameter(torch.ones(1)) if gated else None
        self.alpha_mlp = nn.Parameter(torch.ones(1)) if gated else None

    def forward(self, x):
        """Run the block on ``x`` and return a tensor of the same shape."""
        # Both gates are created together, so one check covers both.
        gated = self.alpha_attn is not None

        attn_out = self.attn(self.ln1(x))
        if gated:
            attn_out = self.alpha_attn * attn_out
        x = x + attn_out

        mlp_out = self.mlp(self.ln2(x))
        if gated:
            mlp_out = self.alpha_mlp * mlp_out
        x = x + mlp_out

        if self.resid_topk > 0:
            x = HardTopK.apply(x, self.resid_topk)
        if self.vq is not None:
            x = self.vq(x)
        if self.subspace is not None:
            x = self.subspace(x)
        return x
+
+
class LocalGPT(nn.Module):
    """GPT-style language model built from ``LocalBlock``s and ``LocalLinear`` layers.

    Pipeline: token + position embeddings -> dropout -> ``n_layer`` blocks ->
    final LayerNorm variant -> ``LocalLinear`` head producing vocabulary logits.
    """

    def __init__(self, config: LocalGPTConfig):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.block_size, config.n_embd)
        if config.freeze_emb:
            self.tok_emb.weight.requires_grad_(False)
            self.pos_emb.weight.requires_grad_(False)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([LocalBlock(config) for _ in range(config.n_layer)])
        self.ln_f = _make_ln(config)
        # Output head: also a LocalLinear (last linear layer before logits).
        # NOTE(review): unlike the block layers, the head is not given the
        # fa_grape options — confirm this is intentional.
        self.head = LocalLinear(config.n_embd, config.vocab_size, bias=False, method=config.method,
                                fa_init_mode=config.fa_init_mode,
                                fa_sparse_k=(config.fa_sparse_k or None))

        self.apply(self._init_weights)
        # Re-draw the residual-branch output projections with a smaller std
        # to reduce residual stream growth with depth.
        scaled_std = 0.02 / math.sqrt(2 * config.n_layer)
        for pn, p in self.named_parameters():
            if pn.endswith(("o_proj.weight", "mlp.proj.weight")):
                nn.init.normal_(p, mean=0.0, std=scaled_std)

    def _init_weights(self, m):
        """Initialise linear and embedding weights from N(0, 0.02^2); zero any bias."""
        if isinstance(m, (nn.Linear, LocalLinear)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def num_params(self) -> int:
        """Return the total number of elements over all registered parameters."""
        return sum(p.numel() for p in self.parameters())

    def forward(self, idx, targets=None):
        """Compute logits, and the cross-entropy loss when ``targets`` is given.

        Args:
            idx: integer token ids of shape (B, T) with T <= ``block_size``.
            targets: optional token ids with as many elements as ``idx``.

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``targets`` is ``None``.
        """
        B, T = idx.shape
        assert T <= self.config.block_size
        pos = torch.arange(T, device=idx.device)
        x = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.head(x)
        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens: int, temperature: float = 1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: (B, T) prompt token ids.
            max_new_tokens: number of sampling steps.
            temperature: divisor applied to the last-position logits.
            top_k: if given, sample only from the ``top_k`` most likely tokens.
                Values above the vocabulary size are clamped to it.

        Returns:
            (B, T + max_new_tokens) tensor: the prompt followed by the samples.

        The train/eval mode is not changed here, so call ``.eval()`` first if
        dropout should be inactive during sampling.
        """
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens.
            idx_cond = idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # torch.topk raises when k exceeds the dimension size.
                k = min(top_k, logits.size(-1))
                kth_largest = torch.topk(logits, k).values[:, [-1]]
                logits[logits < kth_largest] = -float("inf")
            probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, 1)
            idx = torch.cat([idx, nxt], dim=1)
        return idx