""" Minimal Vision Transformer for CIFAR-10. Pre-LN with terminal LayerNorm before the classification head — the architecture P4 should target. Designed to be compatible with the snapshot evolution / DFA training framework. Each TransformerBlock is a "layer" for FA-style local credit purposes. """ import torch import torch.nn as nn import torch.nn.functional as F class TransformerBlock(nn.Module): """Pre-LN transformer block: x = x + attn(LN(x)); x = x + mlp(LN(x)).""" def __init__(self, d_model: int, n_heads: int, mlp_ratio: float = 4.0, dropout: float = 0.0): super().__init__() self.ln1 = nn.LayerNorm(d_model) self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True) self.ln2 = nn.LayerNorm(d_model) mlp_hidden = int(d_model * mlp_ratio) self.mlp = nn.Sequential( nn.Linear(d_model, mlp_hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(mlp_hidden, d_model), nn.Dropout(dropout), ) def forward(self, x: torch.Tensor) -> torch.Tensor: # Self-attention sublayer x_norm = self.ln1(x) attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False) x = x + attn_out # MLP sublayer x = x + self.mlp(self.ln2(x)) return x class ViTMini(nn.Module): """Minimal Vision Transformer for CIFAR-10. Patch size 4x4 → 64 patches per image. Plus a learned cls token. Pre-LN with terminal LayerNorm before the head. """ def __init__( self, image_size: int = 32, patch_size: int = 4, in_channels: int = 3, num_classes: int = 10, d_model: int = 128, n_heads: int = 4, num_blocks: int = 4, mlp_ratio: float = 4.0, dropout: float = 0.0, ): super().__init__() assert image_size % patch_size == 0 n_patches = (image_size // patch_size) ** 2 self.n_tokens = n_patches + 1 # +1 for cls token # Patch embedding via Conv2d (equivalent to flatten + linear) self.patch_embed = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size) self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) self.pos_embed = nn.Parameter(torch.zeros(1, self.n_tokens, d_model)) nn.init.trunc_normal_(self.cls_token, std=0.02) nn.init.trunc_normal_(self.pos_embed, std=0.02) self.blocks = nn.ModuleList([ TransformerBlock(d_model, n_heads, mlp_ratio, dropout) for _ in range(num_blocks) ]) self.out_ln = nn.LayerNorm(d_model) # terminal LN — the P4 trigger self.out_head = nn.Linear(d_model, num_classes) self.num_blocks = num_blocks self.d_model = d_model self.d_hidden = d_model # alias for compatibility with snapshot script def embed(self, x: torch.Tensor) -> torch.Tensor: """Take a flat-CIFAR input (B, 3072) or image (B, 3, 32, 32) → token sequence (B, 65, d_model).""" if x.dim() == 2: # flat input x = x.view(x.size(0), 3, 32, 32) # x: (B, 3, 32, 32) x = self.patch_embed(x) # (B, d, 8, 8) x = x.flatten(2).transpose(1, 2) # (B, 64, d) cls = self.cls_token.expand(x.size(0), -1, -1) x = torch.cat([cls, x], dim=1) # (B, 65, d) x = x + self.pos_embed return x def forward(self, x: torch.Tensor, return_hidden: bool = False): h = self.embed(x) # (B, 65, d) hiddens = [h] if return_hidden else None for block in self.blocks: h = block(h) if return_hidden: hiddens.append(h) # Take cls token, normalize, classify h_cls = self.out_ln(h[:, 0]) # (B, d) logits = self.out_head(h_cls) if return_hidden: return logits, hiddens return logits def forward_from_layer(self, h: torch.Tensor, start_layer: int): """Run forward from a given block index. h has shape (B, n_tokens, d).""" for i in range(start_layer, self.num_blocks): h = self.blocks[i](h) h_cls = self.out_ln(h[:, 0]) return self.out_head(h_cls)