diff options
Diffstat (limited to 'experiments/resnet_protocol_validation.py')
| -rw-r--r-- | experiments/resnet_protocol_validation.py | 343 |
1 files changed, 343 insertions, 0 deletions
diff --git a/experiments/resnet_protocol_validation.py b/experiments/resnet_protocol_validation.py new file mode 100644 index 0000000..f107231 --- /dev/null +++ b/experiments/resnet_protocol_validation.py @@ -0,0 +1,343 @@ +""" +Protocol validation on SmallResNet (BatchNorm, no LN) — BP/FA/DFA + frozen baseline. +Block-level DFA/FA: credit broadcast across spatial positions, same local loss as ResMLP. +""" +import os, sys, json, argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision, torchvision.transforms as transforms + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from models.small_resnet import SmallResNet +from metrics.credit_metrics import cosine_similarity_batch + + +def get_data(batch_size=128): + tv_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) + te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) + return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), + DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2)) + + +def evaluate(model, loader, dev): + model.eval() + c = n = 0 + with torch.no_grad(): + for x, y in loader: + x, y = x.to(dev), y.to(dev) + c += (model(x).argmax(-1) == y).sum().item() + n += x.size(0) + return c / n + + +def compute_diagnostics(model, x_eval, y_eval, device, method_name, dfa_Bs=None, fa_Bs=None): + """Compute per-layer cosine, ||g_l||, ||h_l|| for SmallResNet.""" + model.eval() + L = model.num_blocks + C = 10 + + # Hidden states + with torch.no_grad(): + _, hiddens = model(x_eval, return_hidden=True) + + # For ||h||: pool each hidden to (B, d) then take norm + hidden_norms = [] + for h in hiddens: + h_pool = F.adaptive_avg_pool2d(h, 1).flatten(1) # (B, d) + hidden_norms.append(float(h_pool.norm(dim=-1).median().item())) + + # BP grads via manual forward + h = model.stem(x_eval) + hs = [h.clone().requires_grad_(True)] + for block in model.blocks: + # Need to handle BN eval mode for frozen + hs.append(block(hs[-1])) + h_pool = F.adaptive_avg_pool2d(hs[-1], 1).flatten(1) + logits = model.out_head(h_pool) + loss = F.cross_entropy(logits, y_eval) + grads = torch.autograd.grad(loss, hs) + + # ||g_l|| using pooled gradient + bp_grad_norms = [] + for g in grads: + g_pool = F.adaptive_avg_pool2d(g, 1).flatten(1) # (B, d) + bp_grad_norms.append(float(g_pool.norm(dim=-1).median().item())) + + # Per-layer cosine + with torch.no_grad(): + e_T = logits.softmax(-1) + e_T[torch.arange(x_eval.size(0)), y_eval] -= 1 + + bp_cosine = [] + d = model.d_hidden + + if method_name == 'fa' and fa_Bs is not None: + # FA: sequential backward from exact pooled gradient + hL_pool_req = F.adaptive_avg_pool2d(hiddens[-1].detach(), 1).flatten(1).requires_grad_(True) + logits_fa = model.out_head(hL_pool_req) + loss_fa = F.cross_entropy(logits_fa, y_eval) + a_credit = torch.autograd.grad(loss_fa, hL_pool_req)[0].detach() + + for l in range(L - 1, -1, -1): + # Compare pooled credit with pooled BP grad + g_pool = F.adaptive_avg_pool2d(grads[l], 1).flatten(1).detach() + bp_cosine.insert(0, cosine_similarity_batch(a_credit, g_pool)) + a_credit = (a_credit @ fa_Bs[l]).detach() + + elif method_name == 'dfa' and dfa_Bs is not None: + for l in range(L): + a_dfa = (e_T @ dfa_Bs[l].T).detach() # (B, d) + g_pool = F.adaptive_avg_pool2d(grads[l], 1).flatten(1).detach() + bp_cosine.append(cosine_similarity_batch(a_dfa, g_pool)) + + elif method_name == 'bp': + bp_cosine = [1.0] * L + + model.train() + return { + 'bp_cosine': bp_cosine, + 'bp_grad_norms_per_layer': bp_grad_norms, + 'hidden_norms_per_layer': hidden_norms, + } + + +def train_bp(model, train_loader, test_loader, dev, epochs, lr, wd): + opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd) + sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) + log = {'train_loss': [], 'train_acc': [], 'test_acc': []} + for ep in range(1, epochs + 1): + model.train() + tl, tc, tn = 0, 0, 0 + for x, y in train_loader: + x, y = x.to(dev), y.to(dev) + logits = model(x) + loss = F.cross_entropy(logits, y) + opt.zero_grad(); loss.backward(); opt.step() + tl += loss.item() * x.size(0) + tc += (logits.argmax(1) == y).sum().item() + tn += x.size(0) + sch.step() + log['train_loss'].append(tl / tn) + log['train_acc'].append(tc / tn) + log['test_acc'].append(evaluate(model, test_loader, dev)) + if ep % 10 == 0 or ep == epochs: + print(f" [BP] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True) + return log + + +def train_dfa(model, train_loader, test_loader, dev, epochs, lr, wd): + d = model.d_hidden + L = model.num_blocks + C = 10 + Bs = [torch.randn(d, C, device=dev) / np.sqrt(C) for _ in range(L)] + block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks] + stem_opt = optim.AdamW(list(model.stem_conv.parameters()) + list(model.stem_bn.parameters()), + lr=lr, weight_decay=wd) + head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd) + all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \ + [optim.lr_scheduler.CosineAnnealingLR(stem_opt, T_max=epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)] + + log = {'train_loss': [], 'train_acc': [], 'test_acc': []} + for ep in range(1, epochs + 1): + model.train() + tl, tc, tn = 0, 0, 0 + for x, y in train_loader: + x, y = x.to(dev), y.to(dev) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1 + # Head + hL_pool = F.adaptive_avg_pool2d(hiddens[-1].detach(), 1).flatten(1) + head_opt.zero_grad() + F.cross_entropy(model.out_head(hL_pool), y).backward() + head_opt.step() + # Blocks + for l in range(L): + h_l = hiddens[l].detach() + a_dfa = (e_T @ Bs[l].T).detach() # (B, d) + rms = (a_dfa ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + a_norm = (a_dfa / rms).unsqueeze(-1).unsqueeze(-1).expand_as(h_l) + f_l = model.blocks[l](h_l) - h_l # residual output only + local_loss = (f_l * a_norm).sum(dim=1).mean() + block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step() + # Stem + a0 = (e_T @ Bs[0].T).detach() + rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + h0 = model.stem(x) + a0_b = (a0 / rms0).unsqueeze(-1).unsqueeze(-1).expand_as(h0) + stem_opt.zero_grad() + (h0 * a0_b).sum(dim=1).mean().backward() + stem_opt.step() + for s in all_sch: s.step() + tl += loss_val.item() * batch; tc += (logits.argmax(1) == y).sum().item(); tn += batch + log['train_loss'].append(tl / tn); log['train_acc'].append(tc / tn) + log['test_acc'].append(evaluate(model, test_loader, dev)) + if ep % 10 == 0 or ep == epochs: + print(f" [DFA] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True) + return log, Bs + + +def train_fa(model, train_loader, test_loader, dev, epochs, lr, wd): + d = model.d_hidden + L = model.num_blocks + Bs = [torch.randn(d, d, device=dev) / np.sqrt(d) for _ in range(L)] + block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks] + stem_opt = optim.AdamW(list(model.stem_conv.parameters()) + list(model.stem_bn.parameters()), + lr=lr, weight_decay=wd) + head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd) + all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \ + [optim.lr_scheduler.CosineAnnealingLR(stem_opt, T_max=epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)] + + log = {'train_loss': [], 'train_acc': [], 'test_acc': []} + for ep in range(1, epochs + 1): + model.train() + tl, tc, tn = 0, 0, 0 + for x, y in train_loader: + x, y = x.to(dev), y.to(dev) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + # Head — get gradient BEFORE step + hL_pool = F.adaptive_avg_pool2d(hiddens[-1].detach(), 1).flatten(1).requires_grad_(True) + logits_out = model.out_head(hL_pool) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad() + loss_out.backward() + a_credit = hL_pool.grad.detach() # (B, d) — pooled gradient + head_opt.step() + # Top-down block updates with FA credit + for l in range(L - 1, -1, -1): + h_l = hiddens[l].detach() + rms = (a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + a_norm = (a_credit / rms).unsqueeze(-1).unsqueeze(-1).expand_as(h_l) + f_l = model.blocks[l](h_l) - h_l + local_loss = (f_l * a_norm).sum(dim=1).mean() + block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step() + a_credit = (a_credit @ Bs[l]).detach() + # Stem + rms0 = (a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + h0 = model.stem(x) + a0_b = (a_credit / rms0).unsqueeze(-1).unsqueeze(-1).expand_as(h0) + stem_opt.zero_grad() + (h0 * a0_b).sum(dim=1).mean().backward() + stem_opt.step() + for s in all_sch: s.step() + tl += loss_val.item() * batch; tc += (logits.argmax(1) == y).sum().item(); tn += batch + log['train_loss'].append(tl / tn); log['train_acc'].append(tc / tn) + log['test_acc'].append(evaluate(model, test_loader, dev)) + if ep % 10 == 0 or ep == epochs: + print(f" [FA] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True) + return log, Bs + + +def freeze_blocks(model): + for p in model.blocks.parameters(): + p.requires_grad_(False) + for m in model.blocks.modules(): + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): + m.eval() + + +def train_frozen(model, train_loader, test_loader, dev, epochs, lr, wd): + opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=wd) + sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) + for ep in range(1, epochs + 1): + model.train() + for m in model.blocks.modules(): + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): + m.eval() + for x, y in train_loader: + x, y = x.to(dev), y.to(dev) + loss = F.cross_entropy(model(x), y) + opt.zero_grad(); loss.backward(); opt.step() + sch.step() + if ep % 10 == 0 or ep == epochs: + acc = evaluate(model, test_loader, dev) + print(f" [Frozen] ep {ep}: acc={acc:.4f}", flush=True) + return evaluate(model, test_loader, dev) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--output', type=str, default='results/resnet_protocol_validation.json') + p.add_argument('--epochs', type=int, default=100) + p.add_argument('--d_hidden', type=int, default=64) + args = p.parse_args() + + dev = torch.device('cuda:0') + train_loader, test_loader = get_data(128) + + # Eval buffer for diagnostics (128 samples, consistent with cifar_resmlp.py) + xs, ys = [], [] + for x, y in test_loader: + xs.append(x); ys.append(y) + if sum(xb.size(0) for xb in xs) >= 128: break + x_eval = torch.cat(xs)[:128].to(dev) + y_eval = torch.cat(ys)[:128].to(dev) + + results = {} + + for seed in [42, 123, 456]: + print(f"\n{'='*60}\nSeed {seed}\n{'='*60}", flush=True) + seed_results = {} + + # BP + print("\n--- BP ---", flush=True) + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + model = SmallResNet(args.d_hidden, 10, 4).to(dev) + bp_log = train_bp(model, train_loader, test_loader, dev, args.epochs, 1e-3, 0.01) + bp_diag = compute_diagnostics(model, x_eval, y_eval, dev, 'bp') + seed_results['bp'] = {'log': bp_log, 'diagnostics': bp_diag} + + # FA + print("\n--- FA ---", flush=True) + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + model = SmallResNet(args.d_hidden, 10, 4).to(dev) + fa_log, fa_Bs = train_fa(model, train_loader, test_loader, dev, args.epochs, 1e-3, 0.01) + fa_diag = compute_diagnostics(model, x_eval, y_eval, dev, 'fa', fa_Bs=fa_Bs) + seed_results['fa'] = {'log': fa_log, 'diagnostics': fa_diag} + + # DFA + print("\n--- DFA ---", flush=True) + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + model = SmallResNet(args.d_hidden, 10, 4).to(dev) + dfa_log, dfa_Bs = train_dfa(model, train_loader, test_loader, dev, args.epochs, 1e-3, 0.01) + dfa_diag = compute_diagnostics(model, x_eval, y_eval, dev, 'dfa', dfa_Bs=dfa_Bs) + seed_results['dfa'] = {'log': dfa_log, 'diagnostics': dfa_diag} + + # Frozen baseline + print("\n--- Frozen ---", flush=True) + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + model = SmallResNet(args.d_hidden, 10, 4).to(dev) + freeze_blocks(model) + frozen_acc = train_frozen(model, train_loader, test_loader, dev, args.epochs, 1e-3, 0.01) + seed_results['frozen_acc'] = frozen_acc + print(f"FINAL frozen: {frozen_acc:.4f}", flush=True) + + results[str(seed)] = seed_results + + with open(args.output, 'w') as f: + json.dump(results, f, indent=2) + print(f"\nSaved: {args.output}", flush=True) + + +if __name__ == '__main__': + main() |
