"""Tiny GPT with switchable softmax/sigmoid causal self-attention. Architecture follows nanoGPT (Karpathy), trimmed to a single file for this Sigmoid Attention reproduction experiment (Ramapuram et al. 2024). """ import math from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F @dataclass class GPTConfig: block_size: int = 256 vocab_size: int = 65 n_layer: int = 6 n_head: int = 6 n_embd: int = 384 dropout: float = 0.2 bias: bool = False # bias in linear layers attn_mode: str = "softmax" # "softmax" or "sigmoid" sigmoid_bias_mode: str = "neg_log_n" # "zero" | "neg_log_n" | "learned" class CausalSelfAttention(nn.Module): def __init__(self, config: GPTConfig): super().__init__() assert config.n_embd % config.n_head == 0 self.n_head = config.n_head self.n_embd = config.n_embd self.head_dim = config.n_embd // config.n_head self.block_size = config.block_size self.attn_mode = config.attn_mode self.sigmoid_bias_mode = config.sigmoid_bias_mode self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) self.attn_drop = nn.Dropout(config.dropout) self.resid_drop = nn.Dropout(config.dropout) causal = torch.tril(torch.ones(config.block_size, config.block_size, dtype=torch.bool)) self.register_buffer("causal_mask", causal, persistent=False) if config.attn_mode == "sigmoid": if config.sigmoid_bias_mode == "zero": init_b = 0.0 else: init_b = -math.log(config.block_size) if config.sigmoid_bias_mode == "learned": self.sig_bias = nn.Parameter(torch.tensor(init_b)) else: self.register_buffer("sig_bias", torch.tensor(init_b), persistent=False) def forward(self, x): B, T, C = x.shape q, k, v = self.qkv(x).split(self.n_embd, dim=-1) q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2) v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2) scores = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5) mask = self.causal_mask[:T, :T] scores = scores.masked_fill(~mask, float("-inf")) if self.attn_mode == "softmax": attn = F.softmax(scores, dim=-1) else: # sigmoid(scores + b). masked -> sigmoid(-inf) = 0 naturally. attn = torch.sigmoid(scores + self.sig_bias) attn = self.attn_drop(attn) out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C) return self.resid_drop(self.proj(out)) class MLP(nn.Module): def __init__(self, config: GPTConfig): super().__init__() self.fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) self.proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) self.drop = nn.Dropout(config.dropout) def forward(self, x): return self.drop(self.proj(F.gelu(self.fc(x)))) class Block(nn.Module): def __init__(self, config: GPTConfig): super().__init__() self.ln1 = nn.LayerNorm(config.n_embd) self.attn = CausalSelfAttention(config) self.ln2 = nn.LayerNorm(config.n_embd) self.mlp = MLP(config) def forward(self, x): x = x + self.attn(self.ln1(x)) x = x + self.mlp(self.ln2(x)) return x class GPT(nn.Module): def __init__(self, config: GPTConfig): super().__init__() self.config = config self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) self.pos_emb = nn.Embedding(config.block_size, config.n_embd) self.drop = nn.Dropout(config.dropout) self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)]) self.ln_f = nn.LayerNorm(config.n_embd) self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.apply(self._init_weights) for pn, p in self.named_parameters(): if pn.endswith("proj.weight"): nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) def _init_weights(self, m): if isinstance(m, nn.Linear): nn.init.normal_(m.weight, mean=0.0, std=0.02) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Embedding): nn.init.normal_(m.weight, mean=0.0, std=0.02) def num_params(self) -> int: return sum(p.numel() for p in self.parameters()) def forward(self, idx, targets=None): B, T = idx.shape assert T <= self.config.block_size, f"seq len {T} > block_size {self.config.block_size}" pos = torch.arange(T, device=idx.device) x = self.drop(self.tok_emb(idx) + self.pos_emb(pos)) for blk in self.blocks: x = blk(x) x = self.ln_f(x) logits = self.head(x) if targets is None: return logits, None loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) return logits, loss @torch.no_grad() def generate(self, idx, max_new_tokens: int, temperature: float = 1.0, top_k=None): for _ in range(max_new_tokens): idx_cond = idx[:, -self.config.block_size :] logits, _ = self(idx_cond) logits = logits[:, -1, :] / temperature if top_k is not None: v, _ = torch.topk(logits, top_k) logits[logits < v[:, [-1]]] = -float("inf") probs = F.softmax(logits, dim=-1) nxt = torch.multinomial(probs, 1) idx = torch.cat([idx, nxt], dim=1) return idx