summaryrefslogtreecommitdiff
path: root/protocol/example_usage.py
blob: 2b2c65e0653c0d668a3204b2492d1dff2aa746ac (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
Minimal example: apply the FA evaluation protocol to a DFA-trained ResMLP.

This script trains a model with DFA, then runs the three-diagnostic protocol.
Expected output: FAIL(D1+D2+D3) — DFA on terminal-LN ResMLP triggers all diagnostics.
"""
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np

from models.residual_mlp import ResidualMLP
from protocol.fa_protocol import FAProtocol


def get_cifar10(batch_size=128, *, root='./data', num_workers=2):
    """Build CIFAR-10 train/test DataLoaders.

    The training split uses random-crop (padding 4) and horizontal-flip
    augmentation. The test split is only converted to tensors and
    normalized. Both splits are downloaded into ``root`` if missing.

    Args:
        batch_size: Mini-batch size for both loaders.
        root: Dataset directory passed to ``torchvision.datasets.CIFAR10``.
        num_workers: Worker processes used by each DataLoader.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    # One shared per-channel Normalize(mean, std) so both splits are
    # guaranteed to use identical statistics (previously duplicated inline).
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2470, 0.2435, 0.2616))
    tf_test = transforms.Compose([transforms.ToTensor(), normalize])
    tf_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    train_set = torchvision.datasets.CIFAR10(root, train=True, download=True,
                                             transform=tf_train)
    test_set = torchvision.datasets.CIFAR10(root, train=False, download=True,
                                            transform=tf_test)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader


def _dfa_signal(e_T, B):
    """Project the output error through feedback matrix ``B``; RMS-normalize per sample.

    Each row of ``e_T @ B.T`` is divided by its root-mean-square (plus 1e-6
    to avoid division by zero), so every sample contributes a unit-RMS
    teaching signal. The result is detached from any autograd graph.
    """
    a = (e_T @ B.T).detach()
    rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
    return a / rms


def train_dfa(model, train_loader, device, epochs=30, *, num_classes=10,
              lr=1e-3, weight_decay=0.01):
    """Train ``model`` in place with direct feedback alignment (DFA).

    Canonical variant: no gradient clipping, mean reduction over the batch.

    Update rules:
    - Hidden block ``l`` is updated with the surrogate loss
      ``mean_batch(sum(f_l(h_l) * s_l))``. Here ``s_l`` is the output error
      ``softmax(logits) - one_hot(y)``, projected through a fixed random
      matrix ``Bs[l]`` and RMS-normalized per sample.
    - The embedding reuses block 0's signal.
    - The output norm and head (``out_ln`` / ``out_head``) receive true
      cross-entropy gradients on the detached final hidden state.

    Model interface (assumed, TODO confirm against ResidualMLP):
    - ``model(x, return_hidden=True)`` returns ``(logits, hiddens)``.
    - ``hiddens[l]`` is the input to ``model.blocks[l]``.
    - ``hiddens[-1]`` is the input to ``model.out_ln``.
    - ``model`` exposes ``d_hidden``, ``num_blocks``, ``embed``, ``blocks``,
      ``out_ln`` and ``out_head``.

    Args:
        model: Network to train; its parameters are modified in place.
        train_loader: Iterable of ``(x, y)`` batches; ``x`` is flattened to
            ``(batch, -1)`` before the forward pass.
        device: Device for inputs, labels and feedback matrices.
        epochs: Number of passes over ``train_loader``.
        num_classes: Number of output classes (width of each feedback
            matrix); must match the logits width.
        lr: AdamW learning rate shared by all optimizers.
        weight_decay: AdamW weight decay shared by all optimizers.

    Prints progress every 10 epochs and after the final epoch.
    """
    d = model.d_hidden
    L = model.num_blocks
    C = num_classes
    # One fixed random feedback matrix per block, scaled by 1/sqrt(C).
    Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
    block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=weight_decay)
                  for b in model.blocks]
    embed_opt = optim.AdamW(model.embed.parameters(), lr=lr,
                            weight_decay=weight_decay)
    head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
                           lr=lr, weight_decay=weight_decay)
    for ep in range(1, epochs + 1):
        model.train()
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                # Per-sample d(cross-entropy)/d(logits) = softmax - one_hot(y).
                # Built out-of-place on the logits' device (no mixed-device
                # indexing); numerically identical to in-place "-= 1".
                e_T = logits.softmax(-1) - F.one_hot(y, logits.size(-1)).to(logits.dtype)

            # Output norm + head: ordinary backprop on the detached final state.
            hL = hiddens[-1].detach()
            head_opt.zero_grad()
            F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward()
            head_opt.step()

            # Feedback signals depend only on e_T and the fixed Bs, so compute
            # each once and share block 0's with the embedding below.
            signals = [_dfa_signal(e_T, B) for B in Bs]
            for l in range(L):
                # Block input is detached: each block learns only from its
                # own projected error, never from downstream gradients.
                f_l = model.blocks[l](hiddens[l].detach())
                loss = (f_l * signals[l]).sum(-1).mean()
                block_opts[l].zero_grad()
                loss.backward()
                block_opts[l].step()

            # The embedding reuses block 0's feedback signal (as before).
            h0 = model.embed(x)
            embed_opt.zero_grad()
            (h0 * signals[0]).sum(-1).mean().backward()
            embed_opt.step()
        if ep % 10 == 0 or ep == epochs:
            print(f"  DFA ep {ep}/{epochs}", flush=True)


def main():
    """Train a DFA ResidualMLP on CIFAR-10 and print the FA protocol report.

    Steps:
    1. Seed the torch, NumPy and (if available) CUDA RNGs.
    2. Load CIFAR-10.
    3. Take the first ``eval_size`` test examples, flattened, as the fixed
       diagnostic batch.
    4. Train with DFA.
    5. Run ``FAProtocol`` and print its summary.
    """
    seed = 42
    epochs = 30
    eval_size = 128
    # Reference accuracy forwarded to FAProtocol as ``frozen_baseline_acc``.
    # NOTE(review): its provenance is not visible here — confirm it was
    # measured for this same ResidualMLP configuration and training budget.
    frozen_baseline_acc = 0.349

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    print("Loading CIFAR-10...")
    train_loader, test_loader = get_cifar10()

    # Fixed evaluation batch: the first `eval_size` test images (the test
    # loader does not shuffle), flattened to vectors like the training inputs.
    xs, ys, n_seen = [], [], 0
    for x, y in test_loader:
        xs.append(x.view(x.size(0), -1))
        ys.append(y)
        n_seen += x.size(0)
        if n_seen >= eval_size:
            break
    x_eval = torch.cat(xs)[:eval_size].to(device)
    y_eval = torch.cat(ys)[:eval_size].to(device)

    print(f"Training DFA ({epochs} epochs)...")
    model = ResidualMLP(3072, 256, 10, 4).to(device)
    train_dfa(model, train_loader, device, epochs=epochs)

    print("\nRunning protocol...")
    protocol = FAProtocol(model, x_eval, y_eval)
    report = protocol.run(frozen_baseline_acc=frozen_baseline_acc)
    print(protocol.summary(report))


if __name__ == '__main__':
    # Script entry point: train with DFA, then run and print the FA protocol.
    main()