Diffstat (limited to 'protocol/example_usage.py')
| -rw-r--r-- | protocol/example_usage.py | 106 |
1 file changed, 106 insertions, 0 deletions
diff --git a/protocol/example_usage.py b/protocol/example_usage.py
new file mode 100644
index 0000000..2b2c65e
--- /dev/null
+++ b/protocol/example_usage.py
@@ -0,0 +1,106 @@
+"""
+Minimal example: apply the FA evaluation protocol to a DFA-trained ResMLP.
+
+This script trains a model with DFA, then runs the three-diagnostic protocol.
+Expected output: FAIL(D1+D2+D3) — DFA on terminal-LN ResMLP triggers all diagnostics.
+"""
+import sys, os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision
+import torchvision.transforms as transforms
+import numpy as np
+
+from models.residual_mlp import ResidualMLP
+from protocol.fa_protocol import FAProtocol
+
+
+def get_cifar10(batch_size=128):
+    tv = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+    ])
+    tv_train = transforms.Compose([
+        transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+    ])
+    tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+    te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+    return (torch.utils.data.DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+            torch.utils.data.DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2))
+
+
+def train_dfa(model, train_loader, device, epochs=30):
+    """Minimal DFA training (canonical: no clipping, mean reduction)."""
+    d = model.d_hidden
+    L = model.num_blocks
+    C = 10
+    Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+    block_opts = [optim.AdamW(b.parameters(), lr=1e-3, weight_decay=0.01) for b in model.blocks]
+    embed_opt = optim.AdamW(model.embed.parameters(), lr=1e-3, weight_decay=0.01)
+    head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+                           lr=1e-3, weight_decay=0.01)
+    for ep in range(1, epochs + 1):
+        model.train()
+        for x, y in train_loader:
+            x = x.view(x.size(0), -1).to(device); y = y.to(device)
+            batch = x.size(0)
+            with torch.no_grad():
+                logits, hiddens = model(x, return_hidden=True)
+                e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1
+                hL = hiddens[-1].detach()
+            head_opt.zero_grad()
+            F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward()
+            head_opt.step()
+            for l in range(L):
+                a = (e_T @ Bs[l].T).detach()
+                rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+                f_l = model.blocks[l](hiddens[l].detach())
+                loss = (f_l * (a / rms)).sum(-1).mean()
+                block_opts[l].zero_grad(); loss.backward(); block_opts[l].step()
+            a0 = (e_T @ Bs[0].T).detach()
+            rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+            h0 = model.embed(x)
+            embed_opt.zero_grad(); (h0 * (a0 / rms0)).sum(-1).mean().backward(); embed_opt.step()
+        if ep % 10 == 0:
+            print(f" DFA ep {ep}/{epochs}", flush=True)
+
+
+def main():
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    seed = 42
+    torch.manual_seed(seed); np.random.seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+
+    print("Loading CIFAR-10...")
+    train_loader, test_loader = get_cifar10()
+
+    # Prepare eval buffer
+    xs, ys = [], []
+    for x, y in test_loader:
+        xs.append(x.view(x.size(0), -1)); ys.append(y)
+        if sum(xb.size(0) for xb in xs) >= 128:
+            break
+    x_eval = torch.cat(xs)[:128].to(device)
+    y_eval = torch.cat(ys)[:128].to(device)
+
+    # Train with DFA
+    print("Training DFA (30 epochs)...")
+    model = ResidualMLP(3072, 256, 10, 4).to(device)
+    train_dfa(model, train_loader, device, epochs=30)
+
+    # Run protocol
+    print("\nRunning protocol...")
+    protocol = FAProtocol(model, x_eval, y_eval)
+    report = protocol.run(frozen_baseline_acc=0.349)
+    print(protocol.summary(report))
+
+
+if __name__ == '__main__':
+    main()
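A note on the error signal formed in train_dfa: e_T = softmax(logits) - onehot(y) is the exact gradient of sum-reduced cross-entropy with respect to the logits, which is why the script can build it under torch.no_grad() without backpropagating through the network (the 1/batch factor is supplied later by the .mean() in each surrogate loss). A minimal standalone check, with all names local to this sketch rather than taken from the repository:

    import torch
    import torch.nn.functional as F

    # Verify: d/dlogits of sum-reduced cross-entropy == softmax(logits) - onehot(y)
    torch.manual_seed(0)
    logits = torch.randn(4, 10, requires_grad=True)
    y = torch.randint(0, 10, (4,))

    F.cross_entropy(logits, y, reduction='sum').backward()

    e_T = logits.detach().softmax(-1)
    e_T[torch.arange(4), y] -= 1          # subtract the one-hot targets in place

    print(torch.allclose(logits.grad, e_T, atol=1e-6))  # True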
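The per-block update in train_dfa relies on a surrogate-loss trick: DFA prescribes the update (df_l/dtheta)^T a_l, where a_l = B_l e is the output error projected through a fixed random feedback matrix, and minimizing the inner product of f_l with a (detached) a_l hands exactly that gradient to the optimizer, with no backprop through later blocks. The sketch below checks this identity on a stand-in linear block; it drops the RMS normalization, which only rescales each sample's pseudo-gradient, and the names block, d, C are illustrative assumptions, not repository code:

    import torch
    import torch.nn as nn

    # Verify: grad of the surrogate loss (f_l * a).sum(-1).mean() equals
    # backprop of f_l against the injected DFA pseudo-gradient a / batch.
    torch.manual_seed(0)
    d, C, batch = 8, 10, 4
    block = nn.Linear(d, d)              # stand-in for one residual block
    B = torch.randn(d, C) / C ** 0.5     # fixed random feedback matrix, as in train_dfa

    h = torch.randn(batch, d)            # detached block input
    e = torch.randn(batch, C)            # output-layer error signal
    a = e @ B.T                          # projected pseudo-gradient, shape (batch, d)

    # Route 1: the surrogate loss used by train_dfa
    (block(h) * a).sum(-1).mean().backward()
    g_surrogate = block.weight.grad.clone()

    # Route 2: inject a directly as the gradient at the block output
    block.weight.grad = None
    block(h).backward(gradient=a / batch)

    print(torch.allclose(g_surrogate, block.weight.grad, atol=1e-6))  # True

The same identity justifies the embedding update, which applies the first block's feedback matrix to the embedding output.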
