"""
Minimal Vision Transformer for CIFAR-10. Pre-LN with terminal LayerNorm before
the classification head — the architecture P4 should target.
Designed to be compatible with the snapshot evolution / DFA training framework.
Each TransformerBlock is a "layer" for FA-style local credit purposes.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class TransformerBlock(nn.Module):
"""Pre-LN transformer block: x = x + attn(LN(x)); x = x + mlp(LN(x))."""
def __init__(self, d_model: int, n_heads: int, mlp_ratio: float = 4.0, dropout: float = 0.0):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
self.ln2 = nn.LayerNorm(d_model)
mlp_hidden = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(d_model, mlp_hidden),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_hidden, d_model),
nn.Dropout(dropout),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Self-attention sublayer
x_norm = self.ln1(x)
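        # need_weights=False avoids returning (and averaging) per-head attention weights,
        # and is also a prerequisite for PyTorch's fused attention fast path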
attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
x = x + attn_out
# MLP sublayer
x = x + self.mlp(self.ln2(x))
        return x


class ViTMini(nn.Module):
"""Minimal Vision Transformer for CIFAR-10.
Patch size 4x4 → 64 patches per image. Plus a learned cls token.
Pre-LN with terminal LayerNorm before the head.
"""
def __init__(
self,
image_size: int = 32,
patch_size: int = 4,
in_channels: int = 3,
num_classes: int = 10,
d_model: int = 128,
n_heads: int = 4,
num_blocks: int = 4,
mlp_ratio: float = 4.0,
dropout: float = 0.0,
):
super().__init__()
        assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
n_patches = (image_size // patch_size) ** 2
self.n_tokens = n_patches + 1 # +1 for cls token
        # Patch embedding via strided Conv2d (equivalent to splitting the image into
        # non-overlapping patches and applying a shared linear projection)
self.patch_embed = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)
self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
self.pos_embed = nn.Parameter(torch.zeros(1, self.n_tokens, d_model))
nn.init.trunc_normal_(self.cls_token, std=0.02)
nn.init.trunc_normal_(self.pos_embed, std=0.02)
self.blocks = nn.ModuleList([
TransformerBlock(d_model, n_heads, mlp_ratio, dropout) for _ in range(num_blocks)
])
self.out_ln = nn.LayerNorm(d_model) # terminal LN — the P4 trigger
self.out_head = nn.Linear(d_model, num_classes)
self.num_blocks = num_blocks
self.d_model = d_model
self.d_hidden = d_model # alias for compatibility with snapshot script
def embed(self, x: torch.Tensor) -> torch.Tensor:
"""Take a flat-CIFAR input (B, 3072) or image (B, 3, 32, 32) → token sequence (B, 65, d_model)."""
if x.dim() == 2: # flat input
x = x.view(x.size(0), 3, 32, 32)
# x: (B, 3, 32, 32)
x = self.patch_embed(x) # (B, d, 8, 8)
x = x.flatten(2).transpose(1, 2) # (B, 64, d)
cls = self.cls_token.expand(x.size(0), -1, -1)
x = torch.cat([cls, x], dim=1) # (B, 65, d)
x = x + self.pos_embed
return x
def forward(self, x: torch.Tensor, return_hidden: bool = False):
h = self.embed(x) # (B, 65, d)
hiddens = [h] if return_hidden else None
for block in self.blocks:
h = block(h)
if return_hidden:
hiddens.append(h)
# Take cls token, normalize, classify
h_cls = self.out_ln(h[:, 0]) # (B, d)
logits = self.out_head(h_cls)
if return_hidden:
return logits, hiddens
return logits
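    # Note: in forward(), hiddens[i] is the input to block i (hiddens[0] is the patch +
    # positional embedding), so forward_from_layer(hiddens[k], start_layer=k) reproduces
    # the remainder of that forward pass.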
def forward_from_layer(self, h: torch.Tensor, start_layer: int):
"""Run forward from a given block index. h has shape (B, n_tokens, d)."""
for i in range(start_layer, self.num_blocks):
h = self.blocks[i](h)
h_cls = self.out_ln(h[:, 0])
return self.out_head(h_cls)
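

if __name__ == "__main__":
    # Minimal sanity-check sketch (illustrative only, not part of the training framework):
    # assumes the default configuration (32x32 CIFAR images, patch size 4, d_model 128).
    model = ViTMini()
    model.eval()  # disable dropout (a no-op with the default dropout=0.0)
    images = torch.randn(2, 3, 32, 32)
    flat = images.view(2, -1)  # flat (B, 3072) layout, as accepted by embed()
    logits, hiddens = model(flat, return_hidden=True)
    print(logits.shape)                    # torch.Size([2, 10])
    print(len(hiddens), hiddens[0].shape)  # num_blocks + 1 token sequences of shape (2, 65, 128)
    # hiddens[k] is the input to block k, so resuming from it matches the full forward pass.
    logits_resumed = model.forward_from_layer(hiddens[2], start_layer=2)
    print(torch.allclose(logits, logits_resumed))  # True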