author    YurenHao0426 <Blackhao0426@gmail.com>  2026-05-04 19:50:45 -0500
committer YurenHao0426 <Blackhao0426@gmail.com>  2026-05-04 19:50:45 -0500
commit    b480d0cdc21f944e4adccf6e81cc939b0450c5e9 (patch)
tree      f0e6afb5b3d448d1d6c35d9622d22d63073ca9a7 /models/vit_mini.py
Initial submission code: FA evaluation protocol + reproduction scripts
Reference implementation of the three-diagnostic FA evaluation protocol (scale
stability, reference validity, depth utility) from the NeurIPS 2026 E&D track
paper. Includes models, metrics, and full reproduction pipeline.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'models/vit_mini.py')
-rw-r--r--  models/vit_mini.py  109
1 file changed, 109 insertions(+), 0 deletions(-)
diff --git a/models/vit_mini.py b/models/vit_mini.py
new file mode 100644
index 0000000..af6ba60
--- /dev/null
+++ b/models/vit_mini.py
@@ -0,0 +1,109 @@
+"""
+Minimal Vision Transformer for CIFAR-10. Pre-LN with terminal LayerNorm before
+the classification head — the architecture P4 should target.
+
+Designed to be compatible with the snapshot evolution / DFA training framework.
+Each TransformerBlock is a "layer" for FA-style local credit purposes.
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class TransformerBlock(nn.Module):
+ """Pre-LN transformer block: x = x + attn(LN(x)); x = x + mlp(LN(x))."""
+ def __init__(self, d_model: int, n_heads: int, mlp_ratio: float = 4.0, dropout: float = 0.0):
+ super().__init__()
+ self.ln1 = nn.LayerNorm(d_model)
+ self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
+ self.ln2 = nn.LayerNorm(d_model)
+ mlp_hidden = int(d_model * mlp_ratio)
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model, mlp_hidden),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(mlp_hidden, d_model),
+ nn.Dropout(dropout),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # Self-attention sublayer
+ x_norm = self.ln1(x)
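+        # need_weights=False: the per-head attention weights are not returned
+        # (the second output is None), which also lets PyTorch use its fused
+        # attention fast path when the other fast-path conditions hold.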
+        attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
+        x = x + attn_out
+        # MLP sublayer
+        x = x + self.mlp(self.ln2(x))
+        return x
+
+
+class ViTMini(nn.Module):
+    """Minimal Vision Transformer for CIFAR-10.
+    Patch size 4x4 → 64 patches per image, plus a learned cls token.
+    Pre-LN with terminal LayerNorm before the head.
+    """
+    def __init__(
+        self,
+        image_size: int = 32,
+        patch_size: int = 4,
+        in_channels: int = 3,
+        num_classes: int = 10,
+        d_model: int = 128,
+        n_heads: int = 4,
+        num_blocks: int = 4,
+        mlp_ratio: float = 4.0,
+        dropout: float = 0.0,
+    ):
+        super().__init__()
+        assert image_size % patch_size == 0
+        n_patches = (image_size // patch_size) ** 2
+        self.n_tokens = n_patches + 1  # +1 for cls token
+
+        # Patch embedding via Conv2d (equivalent to flatten + linear)
+        self.patch_embed = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
+        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_tokens, d_model))
+        nn.init.trunc_normal_(self.cls_token, std=0.02)
+        nn.init.trunc_normal_(self.pos_embed, std=0.02)
+
+        self.blocks = nn.ModuleList([
+            TransformerBlock(d_model, n_heads, mlp_ratio, dropout) for _ in range(num_blocks)
+        ])
+        self.out_ln = nn.LayerNorm(d_model)  # terminal LN — the P4 trigger
+        self.out_head = nn.Linear(d_model, num_classes)
+
+        self.num_blocks = num_blocks
+        self.d_model = d_model
+        self.d_hidden = d_model  # alias for compatibility with the snapshot script
+
+    def embed(self, x: torch.Tensor) -> torch.Tensor:
+        """Map a flat CIFAR input (B, 3072) or an image (B, 3, 32, 32) to a token sequence (B, 65, d_model)."""
+        if x.dim() == 2:  # flat input
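+            # NOTE: the flat path assumes CIFAR-shaped input (3, 32, 32);
+            # pass 4-D images directly if image_size or in_channels differ.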
+            x = x.view(x.size(0), 3, 32, 32)
+        # x: (B, 3, 32, 32)
+        x = self.patch_embed(x)           # (B, d, 8, 8)
+        x = x.flatten(2).transpose(1, 2)  # (B, 64, d)
+        cls = self.cls_token.expand(x.size(0), -1, -1)
+        x = torch.cat([cls, x], dim=1)    # (B, 65, d)
+        x = x + self.pos_embed
+        return x
+
+    def forward(self, x: torch.Tensor, return_hidden: bool = False):
+        h = self.embed(x)  # (B, 65, d)
+        hiddens = [h] if return_hidden else None
+        for block in self.blocks:
+            h = block(h)
+            if return_hidden:
+                hiddens.append(h)
+        # Take cls token, normalize, classify
+        h_cls = self.out_ln(h[:, 0])  # (B, d)
+        logits = self.out_head(h_cls)
+        if return_hidden:
+            return logits, hiddens
+        return logits
+
+    def forward_from_layer(self, h: torch.Tensor, start_layer: int):
+        """Run forward from a given block index. h has shape (B, n_tokens, d)."""
+        for i in range(start_layer, self.num_blocks):
+            h = self.blocks[i](h)
+        h_cls = self.out_ln(h[:, 0])
+        return self.out_head(h_cls)
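
For reviewers, a minimal smoke-test sketch of the API above (not part of the
commit; it assumes the file is importable as models.vit_mini, per the Diffstat
path, and uses the constructor defaults):

import torch
from models.vit_mini import ViTMini

model = ViTMini().eval()  # defaults: 32x32 CIFAR-10, 4x4 patches, d_model=128, 4 blocks
x = torch.randn(8, 3, 32, 32)

logits = model(x)
assert logits.shape == (8, 10)

# return_hidden=True yields the embedding plus one hidden state per block,
# so hiddens[i] is exactly the input to blocks[i]
logits, hiddens = model(x, return_hidden=True)
assert len(hiddens) == model.num_blocks + 1
assert all(h.shape == (8, model.n_tokens, model.d_model) for h in hiddens)

# forward_from_layer re-enters mid-network: resuming at block i from hiddens[i]
# reproduces the full forward pass
with torch.no_grad():
    full = model(x)
    for i in range(model.num_blocks):
        assert torch.allclose(full, model.forward_from_layer(hiddens[i], start_layer=i))

# Flat (B, 3072) inputs are reshaped internally by embed()
assert model(x.flatten(1)).shape == (8, 10)

The hiddens[i] → blocks[i] alignment checked here is presumably what the
depth-utility diagnostic named in the commit message relies on when it
re-enters the network at increasing depths.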