summaryrefslogtreecommitdiff
path: root/ep_run/model.py
diff options
context:
space:
mode:
authorYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
committerYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
commitb83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
treeb9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/model.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/model.py')
-rw-r--r--ep_run/model.py156
1 files changed, 156 insertions, 0 deletions
diff --git a/ep_run/model.py b/ep_run/model.py
new file mode 100644
index 0000000..149724b
--- /dev/null
+++ b/ep_run/model.py
@@ -0,0 +1,156 @@
+"""Tiny GPT with switchable softmax/sigmoid causal self-attention.
+
+Architecture follows nanoGPT (Karpathy), trimmed to a single file for this
+Sigmoid Attention reproduction experiment (Ramapuram et al. 2024).
+"""
+import math
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
@dataclass
class GPTConfig:
    """Hyperparameters shared by every module of the tiny GPT.

    Field order is part of the generated ``__init__`` signature, so it is
    kept stable.
    """

    # Sequence / vocabulary sizes.
    block_size: int = 256  # longest sequence; sizes the causal mask and position table
    vocab_size: int = 65

    # Transformer shape. ``n_embd`` must be divisible by ``n_head``.
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384

    # Regularisation and linear-layer options.
    dropout: float = 0.2
    bias: bool = False  # whether the attention/MLP linear layers carry a bias term

    # Attention kernel selection.
    attn_mode: str = "softmax"  # "softmax" or "sigmoid"
    sigmoid_bias_mode: str = "neg_log_n"  # "zero" | "neg_log_n" | "learned"
+
+
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with a softmax or sigmoid kernel.

    In ``"softmax"`` mode the masked, scaled scores are normalised with a
    softmax over the key axis. In ``"sigmoid"`` mode each score is passed
    through ``sigmoid(score + b)`` element-wise, where ``b`` is a scalar
    chosen by ``config.sigmoid_bias_mode``:

    * ``"zero"``    -- fixed ``b = 0``
    * ``"learned"`` -- trainable scalar initialised to ``-log(block_size)``
    * anything else (documented value: ``"neg_log_n"``) -- fixed
      ``b = -log(block_size)``

    Raises:
        ValueError: if ``config.attn_mode`` is neither ``"softmax"`` nor
            ``"sigmoid"``.
    """

    _ATTN_MODES = ("softmax", "sigmoid")

    def __init__(self, config: "GPTConfig"):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Fail at construction time: an unknown mode would otherwise reach the
        # sigmoid branch of forward() without ``sig_bias`` ever being created.
        if config.attn_mode not in self._ATTN_MODES:
            raise ValueError(
                f"attn_mode must be one of {self._ATTN_MODES}, got {config.attn_mode!r}"
            )

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.block_size = config.block_size
        self.attn_mode = config.attn_mode
        self.sigmoid_bias_mode = config.sigmoid_bias_mode

        # One fused projection produces q, k and v; ``proj`` mixes the heads back.
        self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_drop = nn.Dropout(config.dropout)
        self.resid_drop = nn.Dropout(config.dropout)

        # Lower-triangular boolean mask (True = may attend). Non-persistent, so
        # it is not written to the state dict.
        causal = torch.tril(torch.ones(config.block_size, config.block_size, dtype=torch.bool))
        self.register_buffer("causal_mask", causal, persistent=False)

        if config.attn_mode == "sigmoid":
            if config.sigmoid_bias_mode == "zero":
                init_b = 0.0
            else:
                init_b = -math.log(config.block_size)
            if config.sigmoid_bias_mode == "learned":
                self.sig_bias = nn.Parameter(torch.tensor(init_b))
            else:
                self.register_buffer("sig_bias", torch.tensor(init_b), persistent=False)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape ``(B, T, C)`` with ``C == n_embd`` and
                ``T <= block_size``.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(self.n_embd, dim=-1)
        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        scores = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        mask = self.causal_mask[:T, :T]
        scores = scores.masked_fill(~mask, float("-inf"))

        if self.attn_mode == "softmax":
            attn = F.softmax(scores, dim=-1)
        else:
            # sigmoid(scores + b). masked -> sigmoid(-inf) = 0 naturally.
            attn = torch.sigmoid(scores + self.sig_bias)

        attn = self.attn_drop(attn)
        # (B, n_head, T, head_dim) -> (B, T, C)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.proj(out))
+
+
class MLP(nn.Module):
    """Feed-forward sub-layer: widen 4x, apply GELU, project back, then dropout."""

    def __init__(self, config: "GPTConfig"):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        # Registration order (fc, proj, drop) is kept so parameter ordering
        # and initialisation order stay the same.
        self.fc = nn.Linear(width, hidden, bias=config.bias)
        self.proj = nn.Linear(hidden, width, bias=config.bias)
        self.drop = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return the transformed tensor; the last dimension stays ``n_embd``."""
        expanded = self.fc(x)
        activated = F.gelu(expanded)
        projected = self.proj(activated)
        return self.drop(projected)
+
+
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention, then MLP, each added residually."""

    def __init__(self, config: "GPTConfig"):
        super().__init__()
        width = config.n_embd
        # Registration order (ln1, attn, ln2, mlp) is kept so parameter
        # ordering and initialisation order stay the same.
        self.ln1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out
+
+
class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token and learned position embeddings are summed, passed through
    ``config.n_layer`` blocks and a final LayerNorm, and projected to
    vocabulary logits by an untied linear head.
    """

    def __init__(self, config: "GPTConfig"):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.block_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.apply(self._init_weights)
        # Output projections get a smaller std, scaled down by
        # sqrt(2 * n_layer); this overrides the generic init applied above.
        for pn, p in self.named_parameters():
            if pn.endswith("proj.weight"):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, m):
        """Initialise Linear/Embedding weights to N(0, 0.02) and Linear biases to zero."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def num_params(self) -> int:
        """Return the total number of elements across all parameters."""
        return sum(p.numel() for p in self.parameters())

    def forward(self, idx, targets=None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            idx: integer token ids of shape ``(B, T)`` with ``T <= block_size``.
            targets: optional token ids with the same number of elements as
                ``idx``.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape
            ``(B, T, vocab_size)`` and ``loss`` is ``None`` when ``targets``
            is ``None``.
        """
        B, T = idx.shape
        assert T <= self.config.block_size, f"seq len {T} > block_size {self.config.block_size}"
        pos = torch.arange(T, device=idx.device)
        # pos_emb(pos) is (T, n_embd) and broadcasts over the batch dimension.
        x = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.head(x)
        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens: int, temperature: float = 1.0, top_k=None):
        """Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``.

        Args:
            idx: integer token ids of shape ``(B, T)`` used as the prompt.
            max_new_tokens: number of tokens to append.
            temperature: divisor applied to the last-position logits; must be
                non-zero.
            top_k: if given, only the ``top_k`` highest logits stay eligible.
                Values larger than the vocabulary are clamped to it, which
                applies no filtering.

        Returns:
            Tensor of shape ``(B, T + max_new_tokens)``.
        """
        for _ in range(max_new_tokens):
            # Only the most recent ``block_size`` tokens fit in the position table.
            idx_cond = idx[:, -self.config.block_size :]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # torch.topk raises if k exceeds the size of the last dimension.
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k)
                logits[logits < v[:, [-1]]] = -float("inf")
            probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, 1)
            idx = torch.cat([idx, nxt], dim=1)
        return idx