summaryrefslogtreecommitdiff
path: root/models/vit_mini.py
blob: af6ba60287b48cfa127de10a37fb2c8837793463 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
Minimal Vision Transformer for CIFAR-10. Pre-LN with terminal LayerNorm before
the classification head — the architecture P4 should target.

Designed to be compatible with the snapshot evolution / DFA training framework.
Each TransformerBlock is a "layer" for FA-style local credit purposes.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class TransformerBlock(nn.Module):
    """One pre-LayerNorm transformer encoder block.

    Computes ``x + attn(ln1(x))`` and then ``x + mlp(ln2(x))``. Each sublayer
    reads a normalized copy of the stream while the residual path itself is
    never normalized inside the block.

    Args:
        d_model: Token feature width (same for input and output).
        n_heads: Number of heads for ``nn.MultiheadAttention``.
        mlp_ratio: MLP hidden width as a multiple of ``d_model``
            (the product is truncated with ``int``).
        dropout: Probability passed to the attention module and to both
            ``nn.Dropout`` layers in the MLP.
    """
    def __init__(self, d_model: int, n_heads: int, mlp_ratio: float = 4.0, dropout: float = 0.0):
        super().__init__()
        # Creation order is deliberate: it fixes the state_dict keys and the
        # sequence of RNG draws used to initialize parameters.
        self.ln1 = nn.LayerNorm(d_model)
        # batch_first=True: inputs are laid out as (batch, seq, feature).
        self.attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.ln2 = nn.LayerNorm(d_model)
        hidden_width = int(d_model * mlp_ratio)
        layers = [
            nn.Linear(d_model, hidden_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, d_model),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention sublayer, then the MLP sublayer, each with a residual add."""
        normed = self.ln1(x)
        # Attention weights are not needed; only the output tensor is kept.
        mixed = self.attn(normed, normed, normed, need_weights=False)[0]
        residual = x + mixed
        return residual + self.mlp(self.ln2(residual))


class ViTMini(nn.Module):
    """Minimal Vision Transformer for CIFAR-10.

    The image is split into non-overlapping ``patch_size`` x ``patch_size``
    patches (4x4 on a 32x32 image gives 64 patches), each patch is linearly
    embedded, and a learned cls token is prepended. Pre-LN blocks follow, and
    a terminal LayerNorm is applied to the cls token before the linear head.

    Args:
        image_size: Side length of the square input image.
        patch_size: Side length of each square patch; must divide ``image_size``.
        in_channels: Number of input image channels.
        num_classes: Number of output logits.
        d_model: Token embedding width.
        n_heads: Attention heads per block.
        num_blocks: Number of ``TransformerBlock`` layers.
        mlp_ratio: MLP hidden width as a multiple of ``d_model``.
        dropout: Dropout probability forwarded to every block.

    Raises:
        ValueError: If ``patch_size`` does not evenly divide ``image_size``.
    """
    def __init__(
        self,
        image_size: int = 32,
        patch_size: int = 4,
        in_channels: int = 3,
        num_classes: int = 10,
        d_model: int = 128,
        n_heads: int = 4,
        num_blocks: int = 4,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if image_size % patch_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be divisible by patch_size ({patch_size})"
            )
        n_patches = (image_size // patch_size) ** 2
        self.n_tokens = n_patches + 1  # +1 for cls token

        # NOTE: parameter creation/initialization order below is unchanged on
        # purpose, so state_dict keys and seeded initializations stay identical.
        # Patch embedding via Conv2d (equivalent to flatten + linear)
        self.patch_embed = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_tokens, d_model))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, mlp_ratio, dropout) for _ in range(num_blocks)
        ])
        self.out_ln = nn.LayerNorm(d_model)  # terminal LN — the P4 trigger
        self.out_head = nn.Linear(d_model, num_classes)

        self.num_blocks = num_blocks
        self.d_model = d_model
        self.d_hidden = d_model  # alias for compatibility with snapshot script
        # Kept so `embed` can restore the layout of flattened inputs for any config.
        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels

    def embed(self, x: torch.Tensor) -> torch.Tensor:
        """Turn images into the token sequence consumed by the first block.

        Args:
            x: Either a flattened batch of shape
                ``(B, in_channels * image_size * image_size)`` (``(B, 3072)`` for
                default CIFAR settings) or an image batch
                ``(B, in_channels, image_size, image_size)``.

        Returns:
            Tensor of shape ``(B, n_tokens, d_model)``. The cls token is at
            index 0, followed by the patch tokens, with positional embeddings added.
        """
        if x.dim() == 2:  # flat input: restore channel/spatial layout
            # Uses the configured sizes rather than hard-coded CIFAR ones;
            # reshape (unlike view) also accepts non-contiguous inputs.
            x = x.reshape(x.size(0), self.in_channels, self.image_size, self.image_size)
        x = self.patch_embed(x)             # (B, d, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)    # (B, n_patches, d)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1)      # (B, n_tokens, d)
        return x + self.pos_embed

    def _classify(self, h: torch.Tensor) -> torch.Tensor:
        """Normalize the cls token (index 0) with the terminal LN and project to logits."""
        return self.out_head(self.out_ln(h[:, 0]))

    def forward(self, x: torch.Tensor, return_hidden: bool = False):
        """Classify a batch of images.

        Args:
            x: Flat or image-shaped input (see ``embed``).
            return_hidden: If True, also return the token sequences.

        Returns:
            ``logits`` of shape ``(B, num_classes)``. If ``return_hidden`` is
            True, returns ``(logits, hiddens)``. ``hiddens`` holds
            ``num_blocks + 1`` tensors: the embedding output, then each block's
            output, so ``hiddens[i]`` is the input to block ``i``.
        """
        h = self.embed(x)                   # (B, n_tokens, d)
        hiddens = [h] if return_hidden else None
        for block in self.blocks:
            h = block(h)
            if return_hidden:
                hiddens.append(h)
        logits = self._classify(h)
        if return_hidden:
            return logits, hiddens
        return logits

    def forward_from_layer(self, h: torch.Tensor, start_layer: int):
        """Resume the forward pass at block ``start_layer`` and return logits.

        Args:
            h: Token sequence entering block ``start_layer``, shape
                ``(B, n_tokens, d)``. ``forward(x, return_hidden=True)[1][i]``
                is a valid ``h`` for ``start_layer=i``.
            start_layer: Index of the first block to run. ``num_blocks`` means
                only the terminal LN and head are applied.

        Raises:
            ValueError: If ``start_layer`` is outside ``[0, num_blocks]``.
                A negative index would otherwise wrap around and run the last
                block before block 0.
        """
        if not 0 <= start_layer <= self.num_blocks:
            raise ValueError(
                f"start_layer must be in [0, {self.num_blocks}], got {start_layer}"
            )
        for block in self.blocks[start_layer:]:
            h = block(h)
        return self._classify(h)