summaryrefslogtreecommitdiff
path: root/experiments/vit_frozen_blocks_baseline.py
blob: 8b53198559f6995c215ea7a9520f5e0d791ee36d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""
Frozen-random-blocks baseline for ViT-Mini: train BP and DFA where the 4
transformer blocks are randomly initialized and FROZEN (no parameter updates).
Only patch_embed + cls_token + pos_embed + out_ln + out_head are trainable.

This is the codex-round-6 control for the claim that "DFA actually trains the
transformer blocks". If frozen-blocks DFA reaches ≈ 24% (matching the DFA
accuracy of the 4-block ViT-Mini with trainable blocks), the blocks are
passengers: DFA's 24% comes from patch_embed + head learning, routed through
the untrained blocks' fixed mixing. If frozen-blocks DFA stays well below 24%,
the trainable blocks are doing learned work. (A result near chance, 10%, is
less informative: it may only mean the random frozen blocks provide no useful
mixing.)

Usage:
    CUDA_VISIBLE_DEVICES=2 python experiments/vit_frozen_blocks_baseline.py
"""
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np

from models.vit_mini import ViTMini


def get_loaders(batch_size=128, data_root='./data', num_workers=2):
    """Build the CIFAR-10 train/test DataLoaders.

    The training split is augmented with a 4-pixel-padded random 32x32 crop
    and a random horizontal flip; the test split is only converted to a
    tensor. Both splits are normalized with the same fixed per-channel
    mean/std. The dataset is downloaded into ``data_root`` if missing.

    Args:
        batch_size: batch size used by both loaders.
        data_root: directory CIFAR-10 is read from / downloaded to.
        num_workers: number of worker processes for each DataLoader.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    # One shared Normalize so the train and test pipelines cannot drift apart.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    train_set = torchvision.datasets.CIFAR10(data_root, train=True, download=True, transform=train_tf)
    test_set = torchvision.datasets.CIFAR10(data_root, train=False, download=True, transform=test_tf)
    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers),
        DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers),
    )


def evaluate(model, loader, dev):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    The whole model is switched to eval mode and is left there; the training
    loops re-enable train mode themselves at the start of each epoch.
    Inference runs without autograd. ``loader`` must yield ``(inputs,
    targets)`` pairs; an empty loader raises ``ZeroDivisionError``.
    """
    model.eval()
    correct, seen = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(dev)
            targets = targets.to(dev)
            hits = model(inputs).argmax(dim=-1).eq(targets)
            correct += hits.sum().item()
            seen += inputs.size(0)
    return correct / seen


def freeze_blocks(model):
    """Freeze ``model.blocks`` in place: no gradients, eval mode.

    Every parameter under ``model.blocks`` has ``requires_grad`` turned off,
    and the blocks are switched to eval mode. Other submodules of ``model``
    are not touched. Note that a later ``model.train()`` call flips the blocks
    back to train mode, which is why the training loops re-call
    ``model.blocks.eval()`` each epoch.
    """
    model.blocks.requires_grad_(False).eval()


def train_bp_frozen(train_loader, test_loader, dev, epochs=30, seed=42, lr=1e-3, wd=0.05):
    """Backprop baseline: train a 4-block ViT-Mini whose blocks stay frozen.

    The transformer blocks keep their random initialization (see
    ``freeze_blocks``). Every parameter still marked ``requires_grad`` is
    optimized with AdamW under a per-epoch cosine-annealing schedule. The
    cross-entropy loss is backpropagated through the network as usual; the
    frozen block parameters simply receive no gradient.

    Args:
        train_loader: iterable of ``(images, labels)`` training batches.
        test_loader: iterable of ``(images, labels)`` evaluation batches.
        dev: device for the model and batches.
        epochs: number of epochs; also the cosine-annealing period.
        seed: seed for torch, numpy and all CUDA RNGs.
        lr: AdamW learning rate.
        wd: AdamW weight decay.

    Returns:
        The trained model. Test accuracy is printed at epoch 1, every 5th
        epoch and the final epoch.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    model = ViTMini(num_blocks=4, d_model=128, n_heads=4).to(dev)
    freeze_blocks(model)

    all_params = list(model.parameters())
    trainable = [p for p in all_params if p.requires_grad]
    n_trainable = sum(p.numel() for p in trainable)
    n_total = sum(p.numel() for p in all_params)
    print(f"BP-frozen-blocks: {n_trainable}/{n_total} params trainable", flush=True)

    optimizer = optim.AdamW(trainable, lr=lr, weight_decay=wd)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    for ep in range(1, epochs + 1):
        model.train()
        # train() recurses into the frozen blocks; put them back in eval mode
        # (no dropout etc.).
        model.blocks.eval()
        for images, labels in train_loader:
            images, labels = images.to(dev), labels.to(dev)
            loss = F.cross_entropy(model(images), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()

        if ep in (1, epochs) or ep % 5 == 0:
            acc = evaluate(model, test_loader, dev)
            print(f"  BP-frozen ep {ep}: test_acc={acc:.4f}", flush=True)
    return model


def train_dfa_frozen(train_loader, test_loader, dev, epochs=30, seed=42, lr=1e-3, wd=0.05):
    """DFA baseline: train a 4-block ViT-Mini whose blocks stay frozen.

    The 4 transformer blocks are frozen at random init.
    Trainable: patch_embed, cls_token, pos_embed, out_ln, out_head.
    DFA-style: head with true CE on cls token; embed (patch+cls+pos) with random feedback.

    Per training batch:
      1. A full no-grad forward produces the logits and hidden states. The
         output error is e_T = softmax(logits) - onehot(y), i.e. the gradient
         of the cross-entropy loss with respect to the logits.
      2. out_ln + out_head get the true CE gradient, computed from the
         *detached* cls-token hidden state, so nothing flows back into the
         blocks or the embedding.
      3. The embedding (patch_embed, cls_token, pos_embed) is updated from
         e_T projected through a fixed random matrix B0. There is no
         backprop through the blocks at all.

    Args:
        train_loader: iterable of ``(x, y)`` training batches.
        test_loader: iterable of ``(x, y)`` evaluation batches.
        dev: device for the model, B0 and the batches.
        epochs: number of epochs; also the cosine-annealing period of both
            optimizers.
        seed: seed for torch, numpy and all CUDA RNGs.
        lr: AdamW learning rate (both optimizers).
        wd: AdamW weight decay (both optimizers).

    Returns:
        The trained model. Test accuracy is printed at epoch 1, every 5th
        epoch and the final epoch.
    """
    torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
    m = ViTMini(num_blocks=4, d_model=128, n_heads=4).to(dev)
    freeze_blocks(m)
    n_trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in m.parameters())
    print(f"DFA-frozen-blocks: {n_trainable}/{n_total} params trainable", flush=True)

    # d_model must match the ViTMini built above; C is the number of classes
    # (presumably 10 for CIFAR-10 -- keep in sync with the dataset).
    d_model, C = 128, 10
    # Fixed random feedback matrix (d_model x C). It is sampled once, after
    # seeding and model construction, and is never updated.
    B0 = torch.randn(d_model, C, device=dev) / np.sqrt(C)
    # NOTE: weight decay is also applied to cls_token and pos_embed, as in the
    # BP baseline (which puts all trainable params into a single AdamW).
    embed_opt = optim.AdamW(
        list(m.patch_embed.parameters()) + [m.cls_token, m.pos_embed],
        lr=lr, weight_decay=wd
    )
    head_opt = optim.AdamW(
        list(m.out_head.parameters()) + list(m.out_ln.parameters()),
        lr=lr, weight_decay=wd
    )
    # Both schedulers are stepped once per epoch.
    sch1 = optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs)
    sch2 = optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)
    for ep in range(1, epochs + 1):
        m.train()
        # m.train() also flips the frozen blocks back to train mode; undo that.
        m.blocks.eval()
        for x, y in train_loader:
            x = x.to(dev); y = y.to(dev)
            # No-grad forward: the frozen blocks are used only to produce the
            # output error and the last hidden state.
            with torch.no_grad():
                logits, hi = m(x, return_hidden=True)
                # e_T = softmax(logits) - onehot(y). The in-place edit is safe
                # because e_T is a fresh tensor created under no_grad.
                e_T = logits.softmax(-1); e_T[torch.arange(x.size(0)), y] -= 1
            # presumably the last block's output, before out_ln -- TODO confirm
            # against ViTMini.forward. detach() is redundant under no_grad but
            # makes the cut explicit.
            hL_det = hi[-1].detach()
            # Head update via true CE on cls token
            h_cls = m.out_ln(hL_det[:, 0])
            head_opt.zero_grad()
            F.cross_entropy(m.out_head(h_cls), y).backward()
            head_opt.step()
            # Embed update via DFA feedback
            # a0 = B0 e_T per sample, shape (batch, d_model). It uses the error
            # from the forward pass above, i.e. from before the head step.
            a0 = (e_T @ B0.T).detach()
            # Per-sample RMS of the feedback (+1e-6 to avoid division by zero).
            # NOTE(review): dividing by this rescales every sample's feedback to
            # unit RMS. As a result, well-classified samples (small e_T) push the
            # embedding as hard as misclassified ones, and B0's 1/sqrt(C) scale
            # cancels out (up to the epsilon). Confirm this is intended.
            rms = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
            # presumably patch_embed + cls_token + pos_embed. expand_as below
            # requires shape (batch, tokens, d_model). NOTE(review): if embed
            # contains dropout, this pass draws a different mask than the
            # no-grad forward above.
            h0 = m.embed(x)
            # The same per-sample feedback vector is broadcast to every token.
            a0_b = a0.unsqueeze(1).expand_as(h0)
            # Surrogate loss: its gradient w.r.t. h0 is (a0 / rms) / (batch *
            # tokens) because of the mean over batch and tokens. Gradient
            # descent therefore moves h0 along -B0 e_T, which is the DFA update.
            embed_loss = (h0 * (a0_b / rms.unsqueeze(1))).sum(-1).mean()
            embed_opt.zero_grad()
            embed_loss.backward()
            embed_opt.step()
        sch1.step(); sch2.step()
        if ep % 5 == 0 or ep == 1 or ep == epochs:
            acc = evaluate(m, test_loader, dev)
            print(f"  DFA-frozen ep {ep}: test_acc={acc:.4f}", flush=True)
    return m


def main():
    """Run the BP and DFA frozen-blocks baselines and print a summary.

    Command-line flags:
        --seed    seed passed to both trainers (default 42).
        --epochs  training epochs for each run (default 30).
        --device  torch device string. Defaults to ``cuda:0`` when CUDA is
                  available (i.e. whatever CUDA_VISIBLE_DEVICES selects) and
                  falls back to ``cpu`` otherwise.

    Each trained model is evaluated once more on the test set. The reference
    numbers printed in the summary are hard-coded.
    """
    import argparse
    p = argparse.ArgumentParser()
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--epochs', type=int, default=30)
    p.add_argument('--device', type=str, default=None,
                   help="torch device (default: cuda:0 if available, else cpu)")
    args = p.parse_args()

    # Keep cuda:0 as the default on GPU hosts, but do not crash on CPU-only ones.
    if args.device is not None:
        dev = torch.device(args.device)
    elif torch.cuda.is_available():
        dev = torch.device('cuda:0')
    else:
        dev = torch.device('cpu')
    print(f"Device: {dev}, seed={args.seed}, epochs={args.epochs}", flush=True)
    train_loader, test_loader = get_loaders(batch_size=128)

    print(f"\n=== BP frozen-blocks baseline (4 random-init transformer blocks, frozen), seed={args.seed} ===", flush=True)
    mb = train_bp_frozen(train_loader, test_loader, dev, epochs=args.epochs, seed=args.seed)
    bp_acc = evaluate(mb, test_loader, dev)
    print(f"FINAL BP-frozen-blocks acc: {bp_acc:.4f}", flush=True)

    print(f"\n=== DFA frozen-blocks baseline, seed={args.seed} ===", flush=True)
    md = train_dfa_frozen(train_loader, test_loader, dev, epochs=args.epochs, seed=args.seed)
    dfa_acc = evaluate(md, test_loader, dev)
    print(f"FINAL DFA-frozen-blocks acc: {dfa_acc:.4f}", flush=True)

    print("\n=== Summary ===")
    print(f"BP-frozen-blocks: {bp_acc:.4f}  (chance=0.10)")
    print(f"DFA-frozen-blocks: {dfa_acc:.4f}")
    print("Compare to ViT-Mini 4-block trainable (3-seed avg): BP=0.792, DFA=0.237")
    print("Compare to ViT-Mini 0-block (shallow baseline): BP=0.10, DFA=0.10")
    print()
    print("Interpretation:")
    print("  If DFA-frozen-blocks ≈ 0.237: blocks are passengers, DFA is just learning patch_embed+head")
    print("  If DFA-frozen-blocks << 0.237: trainable blocks ARE doing learned work")
    print("  If DFA-frozen-blocks ~ 0.10: untrained blocks add no useful mixing (less informative)")


# Script entry point: run both frozen-blocks baselines when executed directly.
if __name__ == '__main__':
    main()