diff options
Diffstat (limited to 'models/vit_mini.py')
| -rw-r--r-- | models/vit_mini.py | 109 |
1 file changed, 109 insertions, 0 deletions
"""
Minimal Vision Transformer for CIFAR-10. Pre-LN with terminal LayerNorm before
the classification head — the architecture P4 should target.

Designed to be compatible with the snapshot evolution / DFA training framework.
Each TransformerBlock is a "layer" for FA-style local credit purposes.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class TransformerBlock(nn.Module):
    """Pre-LN transformer block: x = x + attn(LN(x)); x = x + mlp(LN(x)).

    Args:
        d_model: token embedding width.
        n_heads: number of attention heads (must divide d_model).
        mlp_ratio: hidden width of the MLP as a multiple of d_model.
        dropout: dropout probability used in attention and the MLP.
    """

    def __init__(self, d_model: int, n_heads: int, mlp_ratio: float = 4.0, dropout: float = 0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        mlp_hidden = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to a token sequence (B, n_tokens, d_model)."""
        # Self-attention sublayer: pre-norm, residual add.
        x_norm = self.ln1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
        x = x + attn_out
        # MLP sublayer: pre-norm, residual add.
        x = x + self.mlp(self.ln2(x))
        return x


class ViTMini(nn.Module):
    """Minimal Vision Transformer for CIFAR-10.

    Patch size 4x4 on a 32x32 image -> 64 patches, plus a learned cls token.
    Pre-LN blocks with a terminal LayerNorm before the classification head.

    Args:
        image_size: input height/width in pixels (square images assumed).
        patch_size: side length of each square patch; must divide image_size.
        in_channels: number of input channels (3 for CIFAR RGB).
        num_classes: size of the classification head output.
        d_model: token embedding width.
        n_heads: attention heads per block.
        num_blocks: number of TransformerBlocks.
        mlp_ratio: MLP hidden width multiplier inside each block.
        dropout: dropout probability inside each block.

    Raises:
        ValueError: if image_size is not a multiple of patch_size.
    """

    def __init__(
        self,
        image_size: int = 32,
        patch_size: int = 4,
        in_channels: int = 3,
        num_classes: int = 10,
        d_model: int = 128,
        n_heads: int = 4,
        num_blocks: int = 4,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        # Explicit validation instead of `assert` (asserts vanish under -O).
        if image_size % patch_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be divisible by patch_size ({patch_size})"
            )
        n_patches = (image_size // patch_size) ** 2
        self.n_tokens = n_patches + 1  # +1 for cls token

        # Remember the input geometry so embed() can reshape flat inputs for
        # ANY configuration (the previous code hard-coded 3x32x32).
        self.image_size = image_size
        self.in_channels = in_channels

        # Patch embedding via Conv2d (equivalent to flatten + linear per patch).
        self.patch_embed = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_tokens, d_model))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, mlp_ratio, dropout) for _ in range(num_blocks)
        ])
        self.out_ln = nn.LayerNorm(d_model)  # terminal LN — the P4 trigger
        self.out_head = nn.Linear(d_model, num_classes)

        self.num_blocks = num_blocks
        self.d_model = d_model
        self.d_hidden = d_model  # alias for compatibility with snapshot script

    def embed(self, x: torch.Tensor) -> torch.Tensor:
        """Map a flat input (B, C*H*W) or image batch (B, C, H, W) to a token
        sequence (B, n_tokens, d_model), prepending the cls token and adding
        positional embeddings."""
        if x.dim() == 2:  # flat input
            # Use the configured geometry, not hard-coded (3, 32, 32), so
            # non-default image_size/in_channels also accept flat inputs.
            x = x.view(x.size(0), self.in_channels, self.image_size, self.image_size)
        x = self.patch_embed(x)           # (B, d, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)  # (B, n_patches, d)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1)    # (B, n_tokens, d)
        return x + self.pos_embed

    def forward(self, x: torch.Tensor, return_hidden: bool = False):
        """Classify a batch of images.

        Returns logits (B, num_classes); when return_hidden is True, also
        returns the list of per-layer token activations (embedding output
        followed by each block's output, num_blocks + 1 tensors)."""
        h = self.embed(x)  # (B, n_tokens, d)
        hiddens = [h] if return_hidden else None
        for block in self.blocks:
            h = block(h)
            if return_hidden:
                hiddens.append(h)
        # Take cls token, normalize, classify.
        h_cls = self.out_ln(h[:, 0])  # (B, d)
        logits = self.out_head(h_cls)
        if return_hidden:
            return logits, hiddens
        return logits

    def forward_from_layer(self, h: torch.Tensor, start_layer: int):
        """Run the tail of the network from block index `start_layer`.

        `h` is a token sequence (B, n_tokens, d) — e.g. one of the hiddens
        returned by forward(..., return_hidden=True). Returns logits."""
        for i in range(start_layer, self.num_blocks):
            h = self.blocks[i](h)
        h_cls = self.out_ln(h[:, 0])
        return self.out_head(h_cls)
