diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/model.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/model.py')
| -rw-r--r-- | ep_run/model.py | 156 |
1 files changed, 156 insertions, 0 deletions
diff --git a/ep_run/model.py b/ep_run/model.py new file mode 100644 index 0000000..149724b --- /dev/null +++ b/ep_run/model.py @@ -0,0 +1,156 @@ +"""Tiny GPT with switchable softmax/sigmoid causal self-attention. + +Architecture follows nanoGPT (Karpathy), trimmed to a single file for this +Sigmoid Attention reproduction experiment (Ramapuram et al. 2024). +""" +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +@dataclass +class GPTConfig: + block_size: int = 256 + vocab_size: int = 65 + n_layer: int = 6 + n_head: int = 6 + n_embd: int = 384 + dropout: float = 0.2 + bias: bool = False # bias in linear layers + attn_mode: str = "softmax" # "softmax" or "sigmoid" + sigmoid_bias_mode: str = "neg_log_n" # "zero" | "neg_log_n" | "learned" + + +class CausalSelfAttention(nn.Module): + def __init__(self, config: GPTConfig): + super().__init__() + assert config.n_embd % config.n_head == 0 + self.n_head = config.n_head + self.n_embd = config.n_embd + self.head_dim = config.n_embd // config.n_head + self.block_size = config.block_size + self.attn_mode = config.attn_mode + self.sigmoid_bias_mode = config.sigmoid_bias_mode + + self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) + self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + self.attn_drop = nn.Dropout(config.dropout) + self.resid_drop = nn.Dropout(config.dropout) + + causal = torch.tril(torch.ones(config.block_size, config.block_size, dtype=torch.bool)) + self.register_buffer("causal_mask", causal, persistent=False) + + if config.attn_mode == "sigmoid": + if config.sigmoid_bias_mode == "zero": + init_b = 0.0 + else: + init_b = -math.log(config.block_size) + if config.sigmoid_bias_mode == "learned": + self.sig_bias = nn.Parameter(torch.tensor(init_b)) + else: + self.register_buffer("sig_bias", torch.tensor(init_b), persistent=False) + + def forward(self, x): + B, T, C = x.shape + q, k, v = self.qkv(x).split(self.n_embd, dim=-1) + q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) + k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2) + v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2) + + scores = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5) + mask = self.causal_mask[:T, :T] + scores = scores.masked_fill(~mask, float("-inf")) + + if self.attn_mode == "softmax": + attn = F.softmax(scores, dim=-1) + else: + # sigmoid(scores + b). masked -> sigmoid(-inf) = 0 naturally. + attn = torch.sigmoid(scores + self.sig_bias) + + attn = self.attn_drop(attn) + out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C) + return self.resid_drop(self.proj(out)) + + +class MLP(nn.Module): + def __init__(self, config: GPTConfig): + super().__init__() + self.fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) + self.proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.drop = nn.Dropout(config.dropout) + + def forward(self, x): + return self.drop(self.proj(F.gelu(self.fc(x)))) + + +class Block(nn.Module): + def __init__(self, config: GPTConfig): + super().__init__() + self.ln1 = nn.LayerNorm(config.n_embd) + self.attn = CausalSelfAttention(config) + self.ln2 = nn.LayerNorm(config.n_embd) + self.mlp = MLP(config) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + + +class GPT(nn.Module): + def __init__(self, config: GPTConfig): + super().__init__() + self.config = config + self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) + self.pos_emb = nn.Embedding(config.block_size, config.n_embd) + self.drop = nn.Dropout(config.dropout) + self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)]) + self.ln_f = nn.LayerNorm(config.n_embd) + self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + self.apply(self._init_weights) + for pn, p in self.named_parameters(): + if pn.endswith("proj.weight"): + nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, mean=0.0, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, mean=0.0, std=0.02) + + def num_params(self) -> int: + return sum(p.numel() for p in self.parameters()) + + def forward(self, idx, targets=None): + B, T = idx.shape + assert T <= self.config.block_size, f"seq len {T} > block_size {self.config.block_size}" + pos = torch.arange(T, device=idx.device) + x = self.drop(self.tok_emb(idx) + self.pos_emb(pos)) + for blk in self.blocks: + x = blk(x) + x = self.ln_f(x) + logits = self.head(x) + if targets is None: + return logits, None + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + return logits, loss + + @torch.no_grad() + def generate(self, idx, max_new_tokens: int, temperature: float = 1.0, top_k=None): + for _ in range(max_new_tokens): + idx_cond = idx[:, -self.config.block_size :] + logits, _ = self(idx_cond) + logits = logits[:, -1, :] / temperature + if top_k is not None: + v, _ = torch.topk(logits, top_k) + logits[logits < v[:, [-1]]] = -float("inf") + probs = F.softmax(logits, dim=-1) + nxt = torch.multinomial(probs, 1) + idx = torch.cat([idx, nxt], dim=1) + return idx |
