diff options
Diffstat (limited to 'experiments')
19 files changed, 2926 insertions, 0 deletions
diff --git a/experiments/__pycache__/__init__.cpython-313.pyc b/experiments/__pycache__/__init__.cpython-313.pyc Binary files differdeleted file mode 100644 index 5966841..0000000 --- a/experiments/__pycache__/__init__.cpython-313.pyc +++ /dev/null diff --git a/experiments/__pycache__/toy_lq.cpython-313.pyc b/experiments/__pycache__/toy_lq.cpython-313.pyc Binary files differdeleted file mode 100644 index d8710a8..0000000 --- a/experiments/__pycache__/toy_lq.cpython-313.pyc +++ /dev/null diff --git a/experiments/analyze_snapshot_evolution.py b/experiments/analyze_snapshot_evolution.py new file mode 100644 index 0000000..8b9f8af --- /dev/null +++ b/experiments/analyze_snapshot_evolution.py @@ -0,0 +1,60 @@ +""" +Read snapshot evolution JSONs (BP vs DFA over training epochs), summarize and +print comparison tables. Used for the P4 paper figure. + +Usage: + python experiments/analyze_snapshot_evolution.py <json_path> +""" +import sys, json +import numpy as np + + +def summarize(log, name): + eps = [d['epoch'] for d in log] + h_L = [d['hidden_norms'][-1] for d in log] + g_l2 = [d['bp_grad_per_sample_l2_med'][2] if 'bp_grad_per_sample_l2_med' in d + else d['bp_grad_norms_per_sample_med'][2] for d in log] + acc = [d['acc_eval'] for d in log] + print(f"\n{name} ({len(log)} epochs):") + print(f" ||h_L||_2 median: ep0={h_L[0]:.3e} -> ep{eps[len(eps)//2]}={h_L[len(eps)//2]:.3e} -> ep{eps[-1]}={h_L[-1]:.3e}") + print(f" ||BP grad at h_2||_2 median: ep0={g_l2[0]:.3e} -> ep{eps[len(eps)//2]}={g_l2[len(eps)//2]:.3e} -> ep{eps[-1]}={g_l2[-1]:.3e}") + print(f" acc: ep0={acc[0]:.4f} -> ep{eps[-1]}={acc[-1]:.4f}") + print(f" ||h_L|| growth (final/initial): {h_L[-1]/max(h_L[0], 1e-12):.3e}") + print(f" ||BP_g|| change (final/initial): {g_l2[-1]/max(g_l2[0], 1e-30):.3e}") + + +def main(): + path = sys.argv[1] if len(sys.argv) > 1 else 'results/snapshot_evolution_v2/snapshot_evolution_s42.json' + with open(path) as f: + d = json.load(f) + print(f"Loaded {path}") + print(f"config: 
{d.get('config', {})}") + print(f"depth={d.get('depth')}, d_hidden={d.get('d_hidden')}") + if 'bp_log' in d: + summarize(d['bp_log'], 'BP') + if 'dfa_log' in d: + summarize(d['dfa_log'], 'DFA') + + # Print compact per-epoch comparison if both available + if 'bp_log' in d and 'dfa_log' in d: + bp = d['bp_log'] + dfa = d['dfa_log'] + eps = sorted(set([x['epoch'] for x in bp]) & set([x['epoch'] for x in dfa])) + sample_eps = [eps[i] for i in [0, len(eps)//4, len(eps)//2, 3*len(eps)//4, -1]] + print(f"\nPer-epoch sample (BP vs DFA):") + print(f"{'epoch':>6s} {'BP_||h_L||':>12s} {'DFA_||h_L||':>12s} {'BP_||g_2||':>12s} {'DFA_||g_2||':>12s} {'BP_acc':>8s} {'DFA_acc':>8s}") + bp_d = {x['epoch']: x for x in bp} + dfa_d = {x['epoch']: x for x in dfa} + for e in sample_eps: + bdat = bp_d[e] + ddat = dfa_d[e] + bh = bdat['hidden_norms'][-1] + dh = ddat['hidden_norms'][-1] + bg_key = 'bp_grad_per_sample_l2_med' if 'bp_grad_per_sample_l2_med' in bdat else 'bp_grad_norms_per_sample_med' + bg = bdat[bg_key][2] + dg = ddat[bg_key][2] + print(f"{e:>6d} {bh:>12.3e} {dh:>12.3e} {bg:>12.3e} {dg:>12.3e} {bdat['acc_eval']:>8.4f} {ddat['acc_eval']:>8.4f}") + + +if __name__ == '__main__': + main() diff --git a/experiments/dfa_penalty_freshB.py b/experiments/dfa_penalty_freshB.py new file mode 100644 index 0000000..82b192d --- /dev/null +++ b/experiments/dfa_penalty_freshB.py @@ -0,0 +1,183 @@ +""" +DFA canonical λ=1e-2 training + checkpoint save + fresh-B null calibration. +Runs after the main penalty sweep to produce the null calibration on the canonical checkpoint. 
+""" +import os, sys, json, argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision, torchvision.transforms as transforms + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from models.residual_mlp import ResidualMLP +from metrics.credit_metrics import cosine_similarity_batch + + +def get_data(batch_size=128): + tv_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) + te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) + return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), + DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2)) + + +def train_dfa_canonical(model, train_loader, device, epochs, lr, wd, penalty_lam): + """Canonical DFA from cifar_resmlp.py: no grad clipping, mean reduction.""" + d = model.d_hidden + L = model.num_blocks + C = 10 + Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] + block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) + head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=lr, weight_decay=wd) + all_sch = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]) + + for epoch in range(1, epochs + 1): + 
model.train() + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device); y = y.to(device) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1 + hL_det = hiddens[-1].detach() + logits_out = model.out_head(model.out_ln(hL_det)) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad(); loss_out.backward(); head_opt.step() + for l in range(L): + h_l = hiddens[l].detach() + a_dfa = (e_T @ Bs[l].T).detach() + rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + f_l = model.blocks[l](h_l) + local_loss = (f_l * (a_dfa / rms)).sum(dim=-1).mean() + if penalty_lam > 0: + local_loss = local_loss + penalty_lam * (f_l ** 2).sum(dim=-1).mean() + block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step() + a_0 = (e_T @ Bs[0].T).detach() + rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + h0 = model.embed(x) + embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean() + embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step() + for s in all_sch: s.step() + if epoch % 10 == 0 or epoch == epochs: + print(f" [DFA pen] ep {epoch}", flush=True) + return Bs + + +def compute_deep_cosine(model, Bs, x_eval, y_eval, device): + """Compute per-layer DFA cosine on eval buffer.""" + model.eval() + L = model.num_blocks + h0 = model.embed(x_eval.detach()) + hs = [h0.clone().requires_grad_(True)] + for b in model.blocks: + hs.append(hs[-1] + b(hs[-1])) + logits = model.out_head(model.out_ln(hs[-1])) + loss = F.cross_entropy(logits, y_eval) + grads = torch.autograd.grad(loss, hs) + with torch.no_grad(): + e_T = logits.softmax(-1) + e_T[torch.arange(x_eval.size(0)), y_eval] -= 1 + cos_per_layer = [] + for l in range(L): + a_dfa = (e_T @ Bs[l].T).detach() + cos_per_layer.append(cosine_similarity_batch(a_dfa, grads[l].detach())) + acc = (logits.argmax(-1) == y_eval).float().mean().item() + g_norms = [g.norm(dim=-1).median().item() for g in grads] + 
h_norms = [h.detach().norm(dim=-1).median().item() for h in hs] + return cos_per_layer, acc, g_norms, h_norms + + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--seed', type=int, default=42) + p.add_argument('--output_dir', type=str, default='results/dfa_canonical_freshB') + p.add_argument('--n_fresh', type=int, default=20) + args = p.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + device = torch.device('cuda:0') + train_loader, test_loader = get_data(128) + + # Fixed eval buffer + xs, ys = [], [] + for x, y in test_loader: + xs.append(x.view(x.size(0), -1)); ys.append(y) + if sum(xb.size(0) for xb in xs) >= 128: + break + x_eval = torch.cat(xs)[:128].to(device) + y_eval = torch.cat(ys)[:128].to(device) + + L, d, C = 4, 256, 10 + + # Train DFA with λ=1e-2 + print(f"Training DFA canonical λ=0.01, seed={args.seed}", flush=True) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + model = ResidualMLP(3072, d, C, L).to(device) + training_Bs = train_dfa_canonical(model, train_loader, device, 30, 1e-3, 0.01, 0.01) + + # Save checkpoint + ckpt_path = os.path.join(args.output_dir, f'dfa_canonical_lam0.01_s{args.seed}.pt') + torch.save({'state_dict': model.state_dict(), + 'Bs': [B.cpu() for B in training_Bs], + 'seed': args.seed}, ckpt_path) + print(f"Saved checkpoint: {ckpt_path}", flush=True) + + # Compute cosine with training Bs + cos_training, acc, g_norms, h_norms = compute_deep_cosine(model, training_Bs, x_eval, y_eval, device) + deep_cos_training = float(np.mean(cos_training[1:])) # exclude layer 0 + print(f"Training-Bs: acc={acc:.4f}, deep cos={deep_cos_training:+.4f}") + print(f" per-layer cos: {[f'{c:+.4f}' for c in cos_training]}") + print(f" ||g_l||: {[f'{g:.2e}' for g in g_norms]}") + print(f" ||h_l||: {[f'{h:.2e}' for h in h_norms]}") + + # Fresh-B null calibration + print(f"\nFresh-B null calibration ({args.n_fresh} draws)...", flush=True) + fresh_deep_cos = [] + fresh_per_layer = 
[] + for i in range(args.n_fresh): + fresh_Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] + cos_fresh, _, _, _ = compute_deep_cosine(model, fresh_Bs, x_eval, y_eval, device) + deep_fresh = float(np.mean(cos_fresh[1:])) + fresh_deep_cos.append(deep_fresh) + fresh_per_layer.append(cos_fresh) + fresh_mean = np.mean(fresh_deep_cos) + fresh_std_ddof1 = np.std(fresh_deep_cos, ddof=1) + print(f"Fresh-Bs deep cos: {fresh_mean:+.4f} ± {fresh_std_ddof1:.4f} (ddof=1)") + + # Save results + out = { + 'description': f'Canonical DFA λ=0.01 s={args.seed} + fresh-B null (N={args.n_fresh})', + 'training_Bs_deep_cos': deep_cos_training, + 'training_Bs_per_layer_cos': cos_training, + 'training_Bs_acc': acc, + 'training_Bs_g_norms': g_norms, + 'training_Bs_h_norms': h_norms, + 'fresh_Bs_n_draws': args.n_fresh, + 'fresh_Bs_deep_cos_per_draw': fresh_deep_cos, + 'fresh_Bs_deep_mean': fresh_mean, + 'fresh_Bs_deep_std_ddof1': fresh_std_ddof1, + 'fresh_Bs_per_layer_mean': [float(np.mean([fl[l] for fl in fresh_per_layer])) for l in range(L)], + } + out_path = os.path.join(args.output_dir, f'freshB_null_canonical_s{args.seed}.json') + with open(out_path, 'w') as f: + json.dump(out, f, indent=2) + print(f"Saved: {out_path}", flush=True) + + +if __name__ == '__main__': + main() diff --git a/experiments/dfa_penalty_trajectory.py b/experiments/dfa_penalty_trajectory.py new file mode 100644 index 0000000..c46ce0b --- /dev/null +++ b/experiments/dfa_penalty_trajectory.py @@ -0,0 +1,135 @@ +""" +Canonical DFA penalty trajectory: per-epoch ||h_L|| and ||g_L|| for λ ∈ {0, 1e-4, 1e-2}. +3 seeds × 3 λ × 30 epochs. Uses canonical cifar_resmlp.py DFA implementation (no clipping, mean reduction). 
+""" +import os, sys, json, argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision, torchvision.transforms as transforms + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from models.residual_mlp import ResidualMLP + + +def get_data(batch_size=128): + tv_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) + te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) + return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), + DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2)) + + +def diagnose_quick(model, x_eval, y_eval): + model.eval() + x_flat = x_eval.view(x_eval.size(0), -1) + with torch.no_grad(): + logits, hiddens = model(x_flat, return_hidden=True) + h_L = hiddens[-1].norm(dim=-1).median().item() + # BP grad at h_L + h0 = model.embed(x_flat.detach()) + hs = [h0.clone().requires_grad_(True)] + for b in model.blocks: + hs.append(hs[-1] + b(hs[-1])) + logits2 = model.out_head(model.out_ln(hs[-1])) + loss = F.cross_entropy(logits2, y_eval) + grads = torch.autograd.grad(loss, hs) + g_L = grads[-1].norm(dim=-1).median().item() + acc = (logits.argmax(-1) == y_eval).float().mean().item() + model.train() + return h_L, g_L, acc + + +def train_dfa_trajectory(seed, train_loader, x_eval, y_eval, device, epochs, lam): + L, d, C = 4, 256, 10 + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + model = ResidualMLP(3072, d, C, L).to(device) + Bs = 
[torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] + block_opts = [optim.AdamW(block.parameters(), lr=1e-3, weight_decay=0.01) for block in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=1e-3, weight_decay=0.01) + head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=1e-3, weight_decay=0.01) + all_sch = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]) + + log = [] + h_L, g_L, acc = diagnose_quick(model, x_eval, y_eval) + log.append({'epoch': 0, 'h_L': h_L, 'g_L': g_L, 'acc': acc}) + + for epoch in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device); y = y.to(device) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1 + hL_det = hiddens[-1].detach() + logits_out = model.out_head(model.out_ln(hL_det)) + head_opt.zero_grad(); F.cross_entropy(logits_out, y).backward(); head_opt.step() + for l in range(L): + h_l = hiddens[l].detach() + a_dfa = (e_T @ Bs[l].T).detach() + rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + f_l = model.blocks[l](h_l) + local_loss = (f_l * (a_dfa / rms)).sum(dim=-1).mean() + if lam > 0: + local_loss = local_loss + lam * (f_l ** 2).sum(dim=-1).mean() + block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step() + a_0 = (e_T @ Bs[0].T).detach() + rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + h0 = model.embed(x) + embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean() + embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step() + for s in all_sch: s.step() + h_L, g_L, acc = diagnose_quick(model, x_eval, y_eval) + log.append({'epoch': epoch, 'h_L': h_L, 'g_L': g_L, 'acc': acc}) + if epoch % 10 == 0 or epoch == epochs: 
+ print(f" [lam={lam}] s={seed} ep {epoch}: ||h_L||={h_L:.3e} ||g_L||={g_L:.3e} acc={acc:.4f}", flush=True) + return log + + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--output', type=str, default='results/dfa_canonical_penalty_trajectory.json') + args = p.parse_args() + + device = torch.device('cuda:0') + train_loader, test_loader = get_data(128) + # Fixed 128-sample eval buffer (consistent with cifar_resmlp.py compute_diagnostics) + xs, ys = [], [] + for x, y in test_loader: + xs.append(x); ys.append(y) + if sum(xb.size(0) for xb in xs) >= 128: + break + x_eval = torch.cat(xs)[:128].to(device) + y_eval = torch.cat(ys)[:128].to(device) + + results = {} + for lam in [0.0, 1e-4, 1e-2]: + lam_key = f'lam_{lam}' + results[lam_key] = {} + for seed in [42, 123, 456]: + print(f"\n=== λ={lam}, seed={seed} ===", flush=True) + log = train_dfa_trajectory(seed, train_loader, x_eval, y_eval, device, 30, lam) + results[lam_key][str(seed)] = log + + with open(args.output, 'w') as f: + json.dump(results, f, indent=2) + print(f"\nSaved: {args.output}", flush=True) + + +if __name__ == '__main__': + main() diff --git a/experiments/figure_snapshot_evolution.py b/experiments/figure_snapshot_evolution.py new file mode 100644 index 0000000..b06f417 --- /dev/null +++ b/experiments/figure_snapshot_evolution.py @@ -0,0 +1,178 @@ +""" +Generate the snapshot-evolution figure(s) for the paper from existing JSONs. 
+ +Produces: + - figure_snapshot_resmlp.pdf : ResMLP with vs without out_ln, ||h_L|| and ||g|| + over epochs for BP and DFA + - figure_snapshot_vit.pdf : ViT-Mini ||h_L|| and ||g|| over epochs for BP/DFA + +Usage: + python experiments/figure_snapshot_evolution.py +""" +import os, sys, json +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + + +def load_log(path, log_key): + if not os.path.exists(path): + return None + with open(path) as f: + return json.load(f).get(log_key) + + +def trajectory(log, metric): + """Extract a per-epoch trajectory for the given metric.""" + eps = [r['epoch'] for r in log] + if metric == 'h_L': + # last hidden norm — handles both ResMLP (hidden_norms) and ViT (hidden_norms_cls) + values = [] + for r in log: + if 'hidden_norms_cls' in r: + values.append(r['hidden_norms_cls'][-1]) + else: + values.append(r['hidden_norms'][-1]) + elif metric == 'g_2': + values = [] + for r in log: + key = 'bp_grad_per_sample_l2_med' if 'bp_grad_per_sample_l2_med' in r else 'bp_grad_norms_per_sample_med' + values.append(r[key][2]) + elif metric == 'acc': + values = [r['acc_eval'] for r in log] + elif metric == 'gamma_dfa': + values = [r.get('gamma_dfa', float('nan')) for r in log] + else: + return None, None + return np.array(eps), np.array(values) + + +def make_resmlp_figure(out_path): + fig, axes = plt.subplots(2, 2, figsize=(10, 7), sharex=True) + + runs = { + 'with_out_ln_s42': 'results/snapshot_evolution_v2/snapshot_evolution_s42.json', + 'no_out_ln_s42': 'results/snapshot_no_outln_v1/snapshot_noLN_s42.json', + 'no_out_ln_s123': 'results/snapshot_no_outln_v1/snapshot_noLN_s123.json', + 'no_out_ln_s456': 'results/snapshot_no_outln_v1/snapshot_noLN_s456.json', + } + runs_loaded = {k: (load_log(v, 'bp_log'), load_log(v, 'dfa_log')) for k, v in runs.items()} + + # Top row: with out_ln + bp, dfa = runs_loaded['with_out_ln_s42'] + ax = axes[0, 0] + e, v = trajectory(bp, 'h_L'); ax.plot(e, v, 'b-', label='BP', lw=2) 
+ e, v = trajectory(dfa, 'h_L'); ax.plot(e, v, 'r-', label='DFA', lw=2) + ax.set_yscale('log'); ax.set_ylabel(r'$\|h_L\|_2$ (median)') + ax.set_title('ResMLP with terminal LayerNorm (s42)') + ax.legend(); ax.grid(True, alpha=0.3) + + ax = axes[0, 1] + e, v = trajectory(bp, 'g_2'); ax.plot(e, v, 'b-', label='BP', lw=2) + e, v = trajectory(dfa, 'g_2'); ax.plot(e, v, 'r-', label='DFA', lw=2) + ax.set_yscale('log'); ax.set_ylabel(r'$\|\nabla_{h_2} L\|_2$ (BP grad, median)') + ax.set_title('ResMLP with terminal LayerNorm (s42)') + ax.legend(); ax.grid(True, alpha=0.3) + + # Bottom row: no out_ln, mean ± std across 3 seeds + no_ln_bp_h = []; no_ln_bp_g = []; no_ln_dfa_h = []; no_ln_dfa_g = [] + for k in ['no_out_ln_s42', 'no_out_ln_s123', 'no_out_ln_s456']: + bp, dfa = runs_loaded[k] + if bp is None or dfa is None: continue + e_bp, h_bp = trajectory(bp, 'h_L'); _, g_bp = trajectory(bp, 'g_2') + e_dfa, h_dfa = trajectory(dfa, 'h_L'); _, g_dfa = trajectory(dfa, 'g_2') + no_ln_bp_h.append(h_bp); no_ln_bp_g.append(g_bp) + no_ln_dfa_h.append(h_dfa); no_ln_dfa_g.append(g_dfa) + + if no_ln_bp_h: + eps = e_bp + bp_h_arr = np.array(no_ln_bp_h) + bp_g_arr = np.array(no_ln_bp_g) + dfa_h_arr = np.array(no_ln_dfa_h) + dfa_g_arr = np.array(no_ln_dfa_g) + + ax = axes[1, 0] + ax.plot(eps, np.mean(bp_h_arr, 0), 'b-', label='BP', lw=2) + ax.fill_between(eps, np.mean(bp_h_arr, 0)-np.std(bp_h_arr, 0), np.mean(bp_h_arr, 0)+np.std(bp_h_arr, 0), color='b', alpha=0.2) + ax.plot(eps, np.mean(dfa_h_arr, 0), 'r-', label='DFA', lw=2) + ax.fill_between(eps, np.mean(dfa_h_arr, 0)-np.std(dfa_h_arr, 0), np.mean(dfa_h_arr, 0)+np.std(dfa_h_arr, 0), color='r', alpha=0.2) + ax.set_yscale('log'); ax.set_xlabel('epoch'); ax.set_ylabel(r'$\|h_L\|_2$ (median)') + ax.set_title(f'ResMLP WITHOUT terminal LayerNorm (mean ± std, n={len(no_ln_bp_h)})') + ax.legend(); ax.grid(True, alpha=0.3) + + ax = axes[1, 1] + ax.plot(eps, np.mean(bp_g_arr, 0), 'b-', label='BP', lw=2) + ax.fill_between(eps, np.mean(bp_g_arr, 
0)-np.std(bp_g_arr, 0), np.mean(bp_g_arr, 0)+np.std(bp_g_arr, 0), color='b', alpha=0.2) + ax.plot(eps, np.mean(dfa_g_arr, 0), 'r-', label='DFA', lw=2) + ax.fill_between(eps, np.mean(dfa_g_arr, 0)-np.std(dfa_g_arr, 0), np.mean(dfa_g_arr, 0)+np.std(dfa_g_arr, 0), color='r', alpha=0.2) + ax.set_yscale('log'); ax.set_xlabel('epoch'); ax.set_ylabel(r'$\|\nabla_{h_2} L\|_2$ (BP grad, median)') + ax.set_title(f'ResMLP WITHOUT terminal LayerNorm (mean ± std, n={len(no_ln_bp_h)})') + ax.legend(); ax.grid(True, alpha=0.3) + + plt.suptitle('Snapshot evolution: residual stream + BP grad over training\n(top: with terminal LN — DFA explodes; bottom: no terminal LN — DFA still grows but BP grad does NOT collapse)', y=1.02) + plt.tight_layout() + plt.savefig(out_path, bbox_inches='tight', dpi=150) + print(f"Saved {out_path}") + plt.close() + + +def make_vit_figure(out_path): + fig, axes = plt.subplots(1, 2, figsize=(11, 4)) + + runs = sorted([ + f for f in os.listdir('results/snapshot_vit_v1') + if f.startswith('snapshot_vit_s') and f.endswith('.json') + ]) + if not runs: + print("No ViT snapshot JSONs found") + return + + bp_h_list = []; bp_g_list = []; dfa_h_list = []; dfa_g_list = [] + eps = None + for r in runs: + path = f'results/snapshot_vit_v1/{r}' + bp = load_log(path, 'bp_log') + dfa = load_log(path, 'dfa_log') + if bp is None or dfa is None: continue + e_bp, h_bp = trajectory(bp, 'h_L'); _, g_bp = trajectory(bp, 'g_2') + e_dfa, h_dfa = trajectory(dfa, 'h_L'); _, g_dfa = trajectory(dfa, 'g_2') + bp_h_list.append(h_bp); bp_g_list.append(g_bp) + dfa_h_list.append(h_dfa); dfa_g_list.append(g_dfa) + eps = e_bp + + bp_h_arr = np.array(bp_h_list); bp_g_arr = np.array(bp_g_list) + dfa_h_arr = np.array(dfa_h_list); dfa_g_arr = np.array(dfa_g_list) + + ax = axes[0] + ax.plot(eps, np.mean(bp_h_arr, 0), 'b-', label='BP', lw=2) + if len(bp_h_list) > 1: + ax.fill_between(eps, np.mean(bp_h_arr, 0)-np.std(bp_h_arr, 0), np.mean(bp_h_arr, 0)+np.std(bp_h_arr, 0), color='b', alpha=0.2) + 
ax.plot(eps, np.mean(dfa_h_arr, 0), 'r-', label='DFA', lw=2) + if len(dfa_h_list) > 1: + ax.fill_between(eps, np.mean(dfa_h_arr, 0)-np.std(dfa_h_arr, 0), np.mean(dfa_h_arr, 0)+np.std(dfa_h_arr, 0), color='r', alpha=0.2) + ax.set_yscale('log'); ax.set_xlabel('epoch'); ax.set_ylabel(r'$\|h_L^{cls}\|_2$ (median)') + ax.set_title(f'ViT-Mini, terminal LayerNorm (n={len(bp_h_list)})') + ax.legend(); ax.grid(True, alpha=0.3) + + ax = axes[1] + ax.plot(eps, np.mean(bp_g_arr, 0), 'b-', label='BP', lw=2) + if len(bp_g_list) > 1: + ax.fill_between(eps, np.mean(bp_g_arr, 0)-np.std(bp_g_arr, 0), np.mean(bp_g_arr, 0)+np.std(bp_g_arr, 0), color='b', alpha=0.2) + ax.plot(eps, np.mean(dfa_g_arr, 0), 'r-', label='DFA', lw=2) + if len(dfa_g_list) > 1: + ax.fill_between(eps, np.mean(dfa_g_arr, 0)-np.std(dfa_g_arr, 0), np.mean(dfa_g_arr, 0)+np.std(dfa_g_arr, 0), color='r', alpha=0.2) + ax.set_yscale('log'); ax.set_xlabel('epoch'); ax.set_ylabel(r'$\|\nabla_{h_2} L\|_2$ (BP grad, median)') + ax.set_title(f'ViT-Mini, terminal LayerNorm (n={len(bp_g_list)})') + ax.legend(); ax.grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig(out_path, bbox_inches='tight', dpi=150) + print(f"Saved {out_path}") + plt.close() + + +if __name__ == '__main__': + os.makedirs('results/figures', exist_ok=True) + make_resmlp_figure('results/figures/figure_snapshot_resmlp.pdf') + make_vit_figure('results/figures/figure_snapshot_vit.pdf') diff --git a/experiments/frozen_baselines_crossarch.py b/experiments/frozen_baselines_crossarch.py new file mode 100644 index 0000000..a3dd76c --- /dev/null +++ b/experiments/frozen_baselines_crossarch.py @@ -0,0 +1,191 @@ +""" +Frozen-blocks baselines for ViT-Mini and StudentNet. +Trains only embed/head/LN with blocks frozen at random init. +Also trains shallow (no blocks) variant for comparison. 
+""" +import os, sys, json, argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset +import torchvision, torchvision.transforms as transforms + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from models.vit_mini import ViTMini +from experiments.confirmatory_paper_experiments import ( + StudentNet, TeacherNet, generate_synth_dataset, set_seed +) + + +def get_cifar10(batch_size=128): + tv_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) + te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) + return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), + DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2)) + + +def evaluate(model, loader, device, is_vit=False): + model.eval() + c = n = 0 + with torch.no_grad(): + for x, y in loader: + x = x.to(device); y = y.to(device) + if not is_vit: + x = x.view(x.size(0), -1) if x.dim() == 4 else x + preds = model(x).argmax(-1) + c += (preds == y).sum().item() + n += x.size(0) + return c / n + + +def freeze_blocks(model): + for p in model.blocks.parameters(): + p.requires_grad_(False) + + +# ─── ViT-Mini frozen/shallow ──────────────────────────────────────────── + +def train_vit_frozen(seed, train_loader, test_loader, device, epochs, lr, wd): + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + model = ViTMini(d_model=128, n_heads=4, num_blocks=4, num_classes=10).to(device) + freeze_blocks(model) + 
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + total = sum(p.numel() for p in model.parameters()) + print(f" ViT-Mini frozen: {trainable}/{total} trainable params", flush=True) + opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=wd) + sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) + for ep in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.to(device); y = y.to(device) + loss = F.cross_entropy(model(x), y) + opt.zero_grad(); loss.backward(); opt.step() + sch.step() + if ep % 10 == 0 or ep == epochs: + acc = evaluate(model, test_loader, device, is_vit=True) + print(f" [ViT-frozen] s={seed} ep {ep}: acc={acc:.4f}", flush=True) + return evaluate(model, test_loader, device, is_vit=True) + + +def train_vit_shallow(seed, train_loader, test_loader, device, epochs, lr, wd): + """ViT with num_blocks=0: just patch_embed + cls + pos + LN + head.""" + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + model = ViTMini(d_model=128, n_heads=4, num_blocks=0, num_classes=10).to(device) + trainable = sum(p.numel() for p in model.parameters()) + print(f" ViT-Mini shallow: {trainable} params (no blocks)", flush=True) + opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd) + sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) + for ep in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.to(device); y = y.to(device) + loss = F.cross_entropy(model(x), y) + opt.zero_grad(); loss.backward(); opt.step() + sch.step() + if ep % 10 == 0 or ep == epochs: + acc = evaluate(model, test_loader, device, is_vit=True) + print(f" [ViT-shallow] s={seed} ep {ep}: acc={acc:.4f}", flush=True) + return evaluate(model, test_loader, device, is_vit=True) + + +# ─── StudentNet frozen/shallow ────────────────────────────────────────── + +def train_student_frozen(seed, train_loader, test_loader, device, epochs, lr, wd, alpha=1.0): 
+ set_seed(seed) + model = StudentNet(128, 10, 4, alpha).to(device) + freeze_blocks(model) + trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + total = sum(p.numel() for p in model.parameters()) + print(f" StudentNet frozen: {trainable}/{total} trainable params", flush=True) + opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=wd) + sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) + for ep in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.to(device); y = y.to(device) + loss = F.cross_entropy(model(x), y) + opt.zero_grad(); loss.backward(); opt.step() + sch.step() + if ep % 10 == 0 or ep == epochs: + acc = evaluate(model, test_loader, device) + print(f" [Student-frozen] s={seed} ep {ep}: acc={acc:.4f}", flush=True) + return evaluate(model, test_loader, device) + + +def train_student_shallow(seed, train_loader, test_loader, device, epochs, lr, wd, alpha=1.0): + """StudentNet with num_blocks=0: just out_head (input is d_hidden already).""" + set_seed(seed) + model = StudentNet(128, 10, 0, alpha).to(device) + trainable = sum(p.numel() for p in model.parameters()) + print(f" StudentNet shallow: {trainable} params (no blocks)", flush=True) + opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd) + sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) + for ep in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.to(device); y = y.to(device) + loss = F.cross_entropy(model(x), y) + opt.zero_grad(); loss.backward(); opt.step() + sch.step() + if ep % 10 == 0 or ep == epochs: + acc = evaluate(model, test_loader, device) + print(f" [Student-shallow] s={seed} ep {ep}: acc={acc:.4f}", flush=True) + return evaluate(model, test_loader, device) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--output', type=str, default='results/frozen_baselines_crossarch.json') + args = p.parse_args() + + device = torch.device('cuda:0') 
+ + results = {} + + # ── ViT-Mini (CIFAR-10, 60 epochs) ── + print("\n=== ViT-Mini frozen baselines ===", flush=True) + train_loader, test_loader = get_cifar10(128) + for seed in [42, 123, 456]: + print(f"\n--- ViT-Mini seed={seed} ---", flush=True) + frozen_acc = train_vit_frozen(seed, train_loader, test_loader, device, 60, 1e-3, 0.05) + shallow_acc = train_vit_shallow(seed, train_loader, test_loader, device, 60, 1e-3, 0.05) + results[f'vit_frozen_s{seed}'] = frozen_acc + results[f'vit_shallow_s{seed}'] = shallow_acc + print(f" FINAL ViT s={seed}: frozen={frozen_acc:.4f}, shallow={shallow_acc:.4f}", flush=True) + + # ── StudentNet (synthetic, 80 epochs) ── + print("\n=== StudentNet frozen baselines ===", flush=True) + L, d, C, alpha = 4, 128, 10, 1.0 + for seed in [42, 123, 456]: + print(f"\n--- StudentNet seed={seed} ---", flush=True) + set_seed(seed) + teacher = TeacherNet(d, L, C, alpha, seed=0).to(device) + X_tr, Y_tr = generate_synth_dataset(teacher, 50*256, d, device, seed=seed) + X_te, Y_te = generate_synth_dataset(teacher, 2000, d, device, seed=seed+10000) + s_train = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=256, shuffle=True) + s_test = DataLoader(TensorDataset(X_te, Y_te), batch_size=256, shuffle=False) + + frozen_acc = train_student_frozen(seed, s_train, s_test, device, 80, 1e-3, 0.01, alpha) + shallow_acc = train_student_shallow(seed, s_train, s_test, device, 80, 1e-3, 0.01, alpha) + results[f'student_frozen_s{seed}'] = frozen_acc + results[f'student_shallow_s{seed}'] = shallow_acc + print(f" FINAL Student s={seed}: frozen={frozen_acc:.4f}, shallow={shallow_acc:.4f}", flush=True) + + with open(args.output, 'w') as f: + json.dump(results, f, indent=2) + print(f"\nSaved: {args.output}", flush=True) + + +if __name__ == '__main__': + main() diff --git a/experiments/resnet_frozen_blocks_baseline.py b/experiments/resnet_frozen_blocks_baseline.py new file mode 100644 index 0000000..787876d --- /dev/null +++ 
b/experiments/resnet_frozen_blocks_baseline.py @@ -0,0 +1,278 @@ +""" +Frozen-blocks and shallow baselines for a small CIFAR-10 ResNet (BatchNorm, +no LayerNorm) — codex-round-10 control to test whether the DFA "active-harm" +walk-back generalizes from LN-based architectures (ViT-Mini, ResMLP) to a +BN-based residual architecture. + +Conditions per seed: + - BP shallow (num_blocks=0) + - BP frozen-blocks (num_blocks=4 frozen) + - BP trainable (num_blocks=4) + - DFA shallow (num_blocks=0) + - DFA frozen-blocks (num_blocks=4 frozen) + - DFA trainable (num_blocks=4) + +If DFA-trainable < DFA-shallow on ResNet too → claim becomes "FA fails to train +deep blocks across multiple residual architectures including BN-based" — much +harder to dismiss as LN-specific. +If DFA-trainable ≈ or > DFA-shallow on ResNet → "harmful mode is specific to LN +normalization or terminal-LN architectures" — narrower but still useful claim. + +Usage: + CUDA_VISIBLE_DEVICES=2 python experiments/resnet_frozen_blocks_baseline.py --seed 42 +""" +import sys, os, argparse +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms +import numpy as np + +from models.small_resnet import SmallResNet + + +def get_loaders(batch_size=128): + tv_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) + te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) + return ( + DataLoader(tr, 
def evaluate(model, loader, dev):
    """Return top-1 accuracy of ``model`` over every batch in ``loader``.

    Args:
        model: callable module; ``model(x)`` must return per-class scores with
            the class dimension last.
        loader: iterable of ``(x, y)`` batches.
        dev: device the batches are moved to before the forward pass.

    Returns:
        Fraction of correctly classified samples as a Python float.

    Raises:
        ValueError: if ``loader`` yields no samples (the original failed with
            an uninformative ``ZeroDivisionError``).

    Note:
        The model is switched to eval mode and left there; callers in this
        file call ``model.train()`` again at the start of each epoch.
    """
    model.eval()
    n = c = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(dev), y.to(dev)
            preds = model(x).argmax(-1)
            c += (preds == y).sum().item()
            n += x.size(0)
    if n == 0:
        raise ValueError("evaluate() received an empty loader: accuracy is undefined")
    return c / n
def main():
    """Run all six BP/DFA x {trainable, shallow, frozen-blocks} conditions for one seed.

    Every condition re-seeds torch/numpy/CUDA with ``--seed`` and builds a fresh
    ``SmallResNet`` before training, so the conditions start from the same RNG
    state. Final test accuracies are collected in ``results`` and printed as a
    summary; nothing is written to disk.

    Command line: ``--seed`` (42), ``--epochs`` (60), ``--lr`` (1e-3),
    ``--wd`` (0.01), ``--d_hidden`` (64).

    NOTE(review): indentation below is reconstructed from a flattened diff; in
    particular ``for s in all_sch: s.step()`` is assumed to run once per epoch
    -- confirm against the real file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--epochs', type=int, default=60)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--wd', type=float, default=0.01)
    parser.add_argument('--d_hidden', type=int, default=64)
    args = parser.parse_args()

    # Always the first visible GPU; the module docstring selects the physical
    # card through CUDA_VISIBLE_DEVICES.
    dev = torch.device('cuda:0')
    print(f"Device: {dev}, seed={args.seed}, epochs={args.epochs}", flush=True)
    train_loader, test_loader = get_loaders(batch_size=128)

    results = {}
    C = 10

    # Trainable BP (full 4-block ResNet)
    print(f"\n=== BP trainable (SmallResNet num_blocks=4), seed={args.seed} ===", flush=True)
    torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
    m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=4).to(dev)
    print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True)
    train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-trainable')
    results['bp_trainable'] = evaluate(m, test_loader, dev)
    print(f"FINAL BP-trainable: {results['bp_trainable']:.4f}", flush=True)

    # Trainable DFA — block-level DFA on ResNet (each block as a unit)
    print(f"\n=== DFA trainable (SmallResNet num_blocks=4 block-level DFA), seed={args.seed} ===", flush=True)
    torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
    m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=4).to(dev)
    # We use the same approach as ViT/ResMLP: stem trained with DFA, blocks trained
    # with their own DFA-style local loss per block, head with true CE.
    # For simplicity reuse train_dfa logic but extend it to also train blocks.
    # Since this script focuses on frozen/shallow control, we'll do trainable in a
    # separate inner loop here.
    d_hidden = m.d_hidden; L = m.num_blocks
    # One fixed random feedback matrix per block, sampled after the model is
    # built so it consumes RNG state only after the weights are initialised.
    Bs = [torch.randn(d_hidden, C, device=dev) / np.sqrt(C) for _ in range(L)]
    block_opts = [optim.AdamW(b.parameters(), lr=args.lr, weight_decay=args.wd) for b in m.blocks]
    stem_params = list(m.stem_conv.parameters()) + list(m.stem_bn.parameters())
    stem_opt = optim.AdamW(stem_params, lr=args.lr, weight_decay=args.wd)
    head_opt = optim.AdamW(m.out_head.parameters(), lr=args.lr, weight_decay=args.wd)
    all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + \
              [optim.lr_scheduler.CosineAnnealingLR(stem_opt, T_max=args.epochs),
               optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]
    for ep in range(1, args.epochs + 1):
        m.train()
        for x, y in train_loader:
            x, y = x.to(dev), y.to(dev)
            with torch.no_grad():
                logits, hi = m(x, return_hidden=True)
                # Output error softmax(logits) - onehot(y), built in place.
                e_T = logits.softmax(-1); e_T[torch.arange(x.size(0)), y] -= 1
            hL_det = hi[-1].detach()
            # Head: ordinary cross-entropy on the pooled, detached final hidden state.
            h_pool = F.adaptive_avg_pool2d(hL_det, 1).flatten(1)
            head_opt.zero_grad()
            F.cross_entropy(m.out_head(h_pool), y).backward()
            head_opt.step()
            # Blocks: each block gets its own local loss <block output, projected error>.
            for l in range(L):
                h_l = hi[l].detach()
                a_l = (e_T @ Bs[l].T).detach()
                rms = (a_l ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
                # RMS-normalised credit, broadcast over the two trailing dims of h_l.
                a_l_norm = (a_l / rms).unsqueeze(-1).unsqueeze(-1).expand_as(h_l)
                f_l = m.blocks[l](h_l)
                local_loss = (f_l * a_l_norm).sum(dim=1).mean()
                block_opts[l].zero_grad(); local_loss.backward()
                torch.nn.utils.clip_grad_norm_(m.blocks[l].parameters(), 1.0)
                block_opts[l].step()
            # Stem: same local loss, reusing the first block's feedback matrix.
            a_0 = (e_T @ Bs[0].T).detach()
            rms_0 = (a_0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
            h0 = m.stem(x)
            a_0_b = (a_0 / rms_0).unsqueeze(-1).unsqueeze(-1).expand_as(h0)
            stem_loss = (h0 * a_0_b).sum(dim=1).mean()
            stem_opt.zero_grad(); stem_loss.backward(); stem_opt.step()
        for s in all_sch: s.step()
        if ep % 10 == 0 or ep == 1 or ep == args.epochs:
            acc = evaluate(m, test_loader, dev)
            print(f" [DFA-trainable] ep {ep}: test_acc={acc:.4f}", flush=True)
    results['dfa_trainable'] = evaluate(m, test_loader, dev)
    print(f"FINAL DFA-trainable: {results['dfa_trainable']:.4f}", flush=True)

    # BP shallow
    print(f"\n=== BP shallow (SmallResNet num_blocks=0), seed={args.seed} ===", flush=True)
    torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
    m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=0).to(dev)
    print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True)
    train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-shallow')
    results['bp_shallow'] = evaluate(m, test_loader, dev)
    print(f"FINAL BP-shallow: {results['bp_shallow']:.4f}", flush=True)

    # BP frozen-blocks
    print(f"\n=== BP frozen-blocks (SmallResNet num_blocks=4 frozen), seed={args.seed} ===", flush=True)
    torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
    m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=4).to(dev)
    freeze_blocks(m)
    print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True)
    # blocks_frozen=True makes train_bp put the blocks' BatchNorm layers back
    # into eval mode after each model.train() call.
    train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-frozen', blocks_frozen=True)
    results['bp_frozen'] = evaluate(m, test_loader, dev)
    print(f"FINAL BP-frozen-blocks: {results['bp_frozen']:.4f}", flush=True)

    # DFA shallow
    print(f"\n=== DFA shallow (SmallResNet num_blocks=0), seed={args.seed} ===", flush=True)
    torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
    m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=0).to(dev)
    train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-shallow')
    results['dfa_shallow'] = evaluate(m, test_loader, dev)
    print(f"FINAL DFA-shallow: {results['dfa_shallow']:.4f}", flush=True)

    # DFA frozen-blocks
    print(f"\n=== DFA frozen-blocks (SmallResNet num_blocks=4 frozen), seed={args.seed} ===", flush=True)
    torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
    m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=4).to(dev)
    freeze_blocks(m)
    train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-frozen', blocks_frozen=True)
    results['dfa_frozen'] = evaluate(m, test_loader, dev)
    print(f"FINAL DFA-frozen-blocks: {results['dfa_frozen']:.4f}", flush=True)

    print(f"\n=== Small ResNet (BatchNorm) frozen/shallow baseline summary, seed={args.seed} ===")
    for k, v in results.items():
        print(f" {k}: {v:.4f}")
    # A positive gap means the shallow/frozen control beat the DFA-trained deep model.
    print(f"\nKey gaps (DFA):")
    if 'dfa_shallow' in results and 'dfa_trainable' in results:
        print(f" DFA-shallow ({results['dfa_shallow']:.4f}) - DFA-trainable ({results['dfa_trainable']:.4f}) = {results['dfa_shallow']-results['dfa_trainable']:+.4f}")
    if 'dfa_frozen' in results and 'dfa_trainable' in results:
        print(f" DFA-frozen ({results['dfa_frozen']:.4f}) - DFA-trainable ({results['dfa_trainable']:.4f}) = {results['dfa_frozen']-results['dfa_trainable']:+.4f}")
def _pooled(t):
    """Average ``t`` over its two trailing (spatial) dims and flatten to 2-D."""
    return F.adaptive_avg_pool2d(t, 1).flatten(1)


def compute_diagnostics(model, x_eval, y_eval, device, method_name, dfa_Bs=None, fa_Bs=None):
    """Compute per-layer cosine, ||g_l||, ||h_l|| for SmallResNet.

    Args:
        model: network exposing ``stem``, ``blocks``, ``out_head``,
            ``num_blocks`` and a ``forward(x, return_hidden=True)`` that
            returns ``(logits, hiddens)``.
        x_eval, y_eval: fixed evaluation batch and its labels.
        device: unused; kept so existing call sites keep working.
        method_name: ``'bp'``, ``'fa'`` or ``'dfa'``; selects how
            ``bp_cosine`` is computed.
        dfa_Bs: per-block feedback matrices, required for ``'dfa'``.
        fa_Bs: per-block feedback matrices, required for ``'fa'``.

    Returns:
        dict with
        ``bp_cosine`` (one value per block; empty if the matrices needed by
        ``method_name`` are missing), ``bp_grad_norms_per_layer`` and
        ``hidden_norms_per_layer`` (medians over the batch of the pooled
        gradient / hidden-state norms, one entry per hidden state).

    The computation runs in eval mode. Every submodule's train/eval flag is
    restored afterwards; the original ended with an unconditional
    ``model.train()``, which flipped eval-mode models and re-enabled
    BatchNorm statistics updates in frozen blocks.
    """
    saved_modes = [(mod, mod.training) for mod in model.modules()]
    model.eval()
    try:
        L = model.num_blocks

        # Hidden states
        with torch.no_grad():
            _, hiddens = model(x_eval, return_hidden=True)

        # For ||h||: pool each hidden to (B, d) then take norm
        hidden_norms = [float(_pooled(h).norm(dim=-1).median().item()) for h in hiddens]

        # BP grads via manual forward, so the gradient w.r.t. every
        # intermediate hidden state can be requested in one autograd call.
        h = model.stem(x_eval)
        hs = [h.clone().requires_grad_(True)]
        for block in model.blocks:
            hs.append(block(hs[-1]))
        logits = model.out_head(_pooled(hs[-1]))
        loss = F.cross_entropy(logits, y_eval)
        grads = torch.autograd.grad(loss, hs)

        # ||g_l|| using pooled gradient
        bp_grad_norms = [float(_pooled(g).norm(dim=-1).median().item()) for g in grads]

        # Output error softmax(logits) - onehot(y), used as the DFA signal.
        with torch.no_grad():
            e_T = logits.softmax(-1)
            e_T[torch.arange(x_eval.size(0)), y_eval] -= 1

        bp_cosine = []

        if method_name == 'fa' and fa_Bs is not None:
            # FA: sequential backward from exact pooled gradient
            hL_pool_req = _pooled(hiddens[-1].detach()).requires_grad_(True)
            loss_fa = F.cross_entropy(model.out_head(hL_pool_req), y_eval)
            a_credit = torch.autograd.grad(loss_fa, hL_pool_req)[0].detach()

            for l in range(L - 1, -1, -1):
                # Compare pooled credit with pooled BP grad
                g_pool = _pooled(grads[l]).detach()
                bp_cosine.insert(0, cosine_similarity_batch(a_credit, g_pool))
                a_credit = (a_credit @ fa_Bs[l]).detach()

        elif method_name == 'dfa' and dfa_Bs is not None:
            for l in range(L):
                a_dfa = (e_T @ dfa_Bs[l].T).detach()  # (B, d)
                g_pool = _pooled(grads[l]).detach()
                bp_cosine.append(cosine_similarity_batch(a_dfa, g_pool))

        elif method_name == 'bp':
            # BP is compared against itself, so the cosine is 1 by definition.
            bp_cosine = [1.0] * L

        return {
            'bp_cosine': bp_cosine,
            'bp_grad_norms_per_layer': bp_grad_norms,
            'hidden_norms_per_layer': hidden_norms,
        }
    finally:
        # Restore each module individually: model.train(flag) would push one
        # flag to every submodule and undo a frozen-BatchNorm setup.
        for mod, was_training in saved_modes:
            mod.training = was_training
def _local_optimizers(model, epochs, lr, wd):
    """Build one AdamW per block plus stem and head optimizers, each with a cosine schedule."""
    block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks]
    stem_opt = optim.AdamW(list(model.stem_conv.parameters()) + list(model.stem_bn.parameters()),
                           lr=lr, weight_decay=wd)
    head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
    schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs)
                  for o in block_opts + [stem_opt, head_opt]]
    return block_opts, stem_opt, head_opt, schedulers


def _spatial_credit(credit, like):
    """RMS-normalise a (B, d) credit signal and broadcast it to the shape of ``like``."""
    rms = (credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
    return (credit / rms).unsqueeze(-1).unsqueeze(-1).expand_as(like)


def train_dfa(model, train_loader, test_loader, dev, epochs, lr, wd):
    """Train ``model`` with block-level direct feedback alignment (DFA).

    The head is trained with cross-entropy on the pooled, detached final hidden
    state. Each block and the stem minimise a local loss: the inner product of
    their output with the output error projected through a fixed random matrix.

    Args:
        model: network exposing ``stem``, ``stem_conv``, ``stem_bn``,
            ``blocks``, ``out_head``, ``d_hidden`` and ``num_blocks``.
        train_loader, test_loader: iterables of ``(x, y)`` batches.
        dev: device the batches and feedback matrices live on.
        epochs, lr, wd: number of epochs, AdamW learning rate and weight decay.

    Returns:
        ``(log, Bs)``: per-epoch ``train_loss`` / ``train_acc`` / ``test_acc``
        lists and the list of feedback matrices.

    Training metrics are accumulated once per batch and the cosine schedulers
    are stepped once per epoch (``T_max`` is expressed in epochs).
    """
    d = model.d_hidden
    L = model.num_blocks
    C = 10
    # Fixed random feedback matrices, one per block; never trained.
    Bs = [torch.randn(d, C, device=dev) / np.sqrt(C) for _ in range(L)]
    block_opts, stem_opt, head_opt, schedulers = _local_optimizers(model, epochs, lr, wd)

    log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
    for ep in range(1, epochs + 1):
        model.train()
        tl, tc, tn = 0, 0, 0
        for x, y in train_loader:
            x, y = x.to(dev), y.to(dev)
            batch = x.size(0)
            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                loss_val = F.cross_entropy(logits, y)
                # Output error softmax(logits) - onehot(y).
                e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1
            # Head
            hL_pool = F.adaptive_avg_pool2d(hiddens[-1].detach(), 1).flatten(1)
            head_opt.zero_grad()
            F.cross_entropy(model.out_head(hL_pool), y).backward()
            head_opt.step()
            # Blocks
            for l in range(L):
                h_l = hiddens[l].detach()
                a_norm = _spatial_credit((e_T @ Bs[l].T).detach(), h_l)
                f_l = model.blocks[l](h_l) - h_l  # residual output only
                local_loss = (f_l * a_norm).sum(dim=1).mean()
                block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step()
            # Stem (reuses the first block's feedback matrix)
            h0 = model.stem(x)
            a0_b = _spatial_credit((e_T @ Bs[0].T).detach(), h0)
            stem_opt.zero_grad()
            (h0 * a0_b).sum(dim=1).mean().backward()
            stem_opt.step()
            # Per-batch bookkeeping, based on the pre-update forward pass.
            tl += loss_val.item() * batch
            tc += (logits.argmax(1) == y).sum().item()
            tn += batch
        for s in schedulers:
            s.step()
        log['train_loss'].append(tl / tn); log['train_acc'].append(tc / tn)
        log['test_acc'].append(evaluate(model, test_loader, dev))
        if ep % 10 == 0 or ep == epochs:
            print(f" [DFA] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True)
    return log, Bs


def train_fa(model, train_loader, test_loader, dev, epochs, lr, wd):
    """Train ``model`` with block-level feedback alignment (FA).

    The credit signal starts as the exact gradient of the head loss with
    respect to the pooled final hidden state. It is then passed top-down
    through fixed random ``d x d`` matrices, one per block; each block (and
    finally the stem) minimises the inner product of its output with the
    credit it receives.

    Args and return value mirror ``train_dfa``: returns ``(log, Bs)``.

    Training metrics are accumulated once per batch and the cosine schedulers
    are stepped once per epoch (``T_max`` is expressed in epochs).
    """
    d = model.d_hidden
    L = model.num_blocks
    Bs = [torch.randn(d, d, device=dev) / np.sqrt(d) for _ in range(L)]
    block_opts, stem_opt, head_opt, schedulers = _local_optimizers(model, epochs, lr, wd)

    log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
    for ep in range(1, epochs + 1):
        model.train()
        tl, tc, tn = 0, 0, 0
        for x, y in train_loader:
            x, y = x.to(dev), y.to(dev)
            batch = x.size(0)
            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                loss_val = F.cross_entropy(logits, y)
            # Head — get gradient BEFORE step
            hL_pool = F.adaptive_avg_pool2d(hiddens[-1].detach(), 1).flatten(1).requires_grad_(True)
            loss_out = F.cross_entropy(model.out_head(hL_pool), y)
            head_opt.zero_grad()
            loss_out.backward()
            a_credit = hL_pool.grad.detach()  # (B, d) — pooled gradient
            head_opt.step()
            # Top-down block updates with FA credit
            for l in range(L - 1, -1, -1):
                h_l = hiddens[l].detach()
                a_norm = _spatial_credit(a_credit, h_l)
                f_l = model.blocks[l](h_l) - h_l
                local_loss = (f_l * a_norm).sum(dim=1).mean()
                block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step()
                # Hand the credit to the next-lower block through this block's matrix.
                a_credit = (a_credit @ Bs[l]).detach()
            # Stem
            h0 = model.stem(x)
            a0_b = _spatial_credit(a_credit, h0)
            stem_opt.zero_grad()
            (h0 * a0_b).sum(dim=1).mean().backward()
            stem_opt.step()
            # Per-batch bookkeeping, based on the pre-update forward pass.
            tl += loss_val.item() * batch
            tc += (logits.argmax(1) == y).sum().item()
            tn += batch
        for s in schedulers:
            s.step()
        log['train_loss'].append(tl / tn); log['train_acc'].append(tc / tn)
        log['test_acc'].append(evaluate(model, test_loader, dev))
        if ep % 10 == 0 or ep == epochs:
            print(f" [FA] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True)
    return log, Bs
def main():
    """Run BP, FA, DFA and the frozen-blocks baseline on SmallResNet for three seeds.

    For each of the seeds 42, 123 and 456 a fresh 4-block ``SmallResNet`` is
    trained per method. BP/FA/DFA store their training log and diagnostics,
    the frozen baseline stores its final accuracy. Everything is written to
    one JSON file keyed by seed.

    Command line: ``--output`` (default
    ``results/resnet_protocol_validation.json``), ``--epochs`` (100),
    ``--d_hidden`` (64).
    """
    p = argparse.ArgumentParser()
    p.add_argument('--output', type=str, default='results/resnet_protocol_validation.json')
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--d_hidden', type=int, default=64)
    args = p.parse_args()

    # Create the output directory before any training: the original opened
    # the file only at the very end, so a missing directory discarded the
    # results of every run.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    dev = torch.device('cuda:0')
    train_loader, test_loader = get_data(128)

    # Eval buffer for diagnostics (128 samples, consistent with cifar_resmlp.py)
    xs, ys = [], []
    n_buffered = 0
    for x, y in test_loader:
        xs.append(x); ys.append(y)
        n_buffered += x.size(0)
        if n_buffered >= 128:
            break
    x_eval = torch.cat(xs)[:128].to(dev)
    y_eval = torch.cat(ys)[:128].to(dev)

    results = {}

    def save_results():
        # Rewrites the whole file; called after every seed so an interrupted
        # run keeps the seeds that already finished.
        with open(args.output, 'w') as f:
            json.dump(results, f, indent=2)

    def fresh_model(seed):
        # Re-seed before every construction so all methods start from the
        # same initial weights for a given seed.
        torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
        return SmallResNet(args.d_hidden, 10, 4).to(dev)

    for seed in [42, 123, 456]:
        print(f"\n{'='*60}\nSeed {seed}\n{'='*60}", flush=True)
        seed_results = {}

        # BP
        print("\n--- BP ---", flush=True)
        model = fresh_model(seed)
        bp_log = train_bp(model, train_loader, test_loader, dev, args.epochs, 1e-3, 0.01)
        bp_diag = compute_diagnostics(model, x_eval, y_eval, dev, 'bp')
        seed_results['bp'] = {'log': bp_log, 'diagnostics': bp_diag}

        # FA
        print("\n--- FA ---", flush=True)
        model = fresh_model(seed)
        fa_log, fa_Bs = train_fa(model, train_loader, test_loader, dev, args.epochs, 1e-3, 0.01)
        fa_diag = compute_diagnostics(model, x_eval, y_eval, dev, 'fa', fa_Bs=fa_Bs)
        seed_results['fa'] = {'log': fa_log, 'diagnostics': fa_diag}

        # DFA
        print("\n--- DFA ---", flush=True)
        model = fresh_model(seed)
        dfa_log, dfa_Bs = train_dfa(model, train_loader, test_loader, dev, args.epochs, 1e-3, 0.01)
        dfa_diag = compute_diagnostics(model, x_eval, y_eval, dev, 'dfa', dfa_Bs=dfa_Bs)
        seed_results['dfa'] = {'log': dfa_log, 'diagnostics': dfa_diag}

        # Frozen baseline
        print("\n--- Frozen ---", flush=True)
        model = fresh_model(seed)
        freeze_blocks(model)
        frozen_acc = train_frozen(model, train_loader, test_loader, dev, args.epochs, 1e-3, 0.01)
        seed_results['frozen_acc'] = frozen_acc
        print(f"FINAL frozen: {frozen_acc:.4f}", flush=True)

        results[str(seed)] = seed_results
        save_results()

    save_results()
    print(f"\nSaved: {args.output}", flush=True)
def load(path):
    """Return the parsed JSON stored at ``path``, or ``None`` if the file does not exist."""
    try:
        with open(path) as f:
            return json.load(f)
    except FileNotFoundError:
        return None


def _grad_key(log):
    """Return the name under which this log's records store the BP-gradient norms.

    Two names occur in the snapshot files; the newer one is preferred when present.
    """
    if 'bp_grad_per_sample_l2_med' in log[0]:
        return 'bp_grad_per_sample_l2_med'
    return 'bp_grad_norms_per_sample_med'


def _at(values, idx):
    """Return ``values[idx]`` or NaN when the index is out of range."""
    return values[idx] if -len(values) <= idx < len(values) else float('nan')


def field(d, key, layer=2):
    """Extract a per-epoch list of values for a given metric/layer.

    Args:
        d: loaded snapshot dict with ``bp_log`` / ``dfa_log`` lists, or ``None``.
        key: metric name containing ``bp`` or ``dfa`` to select the log, e.g.
            ``'bp_acc'``, ``'dfa_h_L_norm'``, ``'bp_g_l2'``, ``'gamma_dfa'``.
        layer: index into the gradient-norm list for the ``g_l2`` metric.

    Returns:
        A list with one value per logged epoch, or ``None`` when ``d`` is
        ``None``, the key selects no log, or the metric is unknown. The
        original raised ``TypeError`` when the key selected no log.
    """
    if d is None:
        return None
    if 'bp' in key:
        log = d['bp_log']
    elif 'dfa' in key:
        log = d['dfa_log']
    else:
        return None
    metric = key.replace('bp_', '').replace('dfa_', '')
    if metric == 'h_L_norm':
        return [r['hidden_norms'][-1] for r in log]
    if metric == 'h_L2_norm':
        return [r['hidden_norms'][2] if len(r['hidden_norms']) > 2 else None for r in log]
    if metric == 'g_l2':
        if not log:
            return []
        key_in_log = _grad_key(log)
        return [r[key_in_log][layer] for r in log]
    if metric == 'acc':
        return [r['acc_eval'] for r in log]
    if metric == 'gamma_dfa':
        return [r.get('gamma_dfa', float('nan')) for r in log]
    return None


def summary_row(d, label):
    """Print a summary row for the comparison table.

    The row shows final accuracies, first-to-last final-hidden-state norm with
    its growth factor, and first-to-last gradient norm at index 2, for both
    the BP and the DFA log. Prints ``MISSING`` when ``d`` is ``None`` and
    ``EMPTY LOG`` when either log has no records.
    """
    if d is None:
        print(f"{label:35s} MISSING")
        return
    bp = d['bp_log']
    dfa = d['dfa_log']
    if not bp or not dfa:
        print(f"{label:35s} EMPTY LOG")
        return

    bp_h_L_init = _at(bp[0]['hidden_norms'], -1)
    bp_h_L_final = _at(bp[-1]['hidden_norms'], -1)
    dfa_h_L_init = _at(dfa[0]['hidden_norms'], -1)
    dfa_h_L_final = _at(dfa[-1]['hidden_norms'], -1)

    # Resolve the gradient key per log: the original reused the BP log's key
    # for the DFA log and failed when the two files used different names.
    bp_g_key = _grad_key(bp)
    dfa_g_key = _grad_key(dfa)
    # NaN instead of IndexError for logs with fewer than three gradient entries.
    bp_g2_init = _at(bp[0][bp_g_key], 2)
    bp_g2_final = _at(bp[-1][bp_g_key], 2)
    dfa_g2_init = _at(dfa[0][dfa_g_key], 2)
    dfa_g2_final = _at(dfa[-1][dfa_g_key], 2)

    bp_acc = bp[-1]['acc_eval']
    dfa_acc = dfa[-1]['acc_eval']
    # Floor the denominator so a zero initial norm does not divide by zero.
    bp_growth = bp_h_L_final / max(bp_h_L_init, 1e-12)
    dfa_growth = dfa_h_L_final / max(dfa_h_L_init, 1e-12)

    print(f"{label:35s} BP_acc={bp_acc:.3f} DFA_acc={dfa_acc:.3f} "
          f"BP_||h_L||: {bp_h_L_init:.1e}→{bp_h_L_final:.1e} (×{bp_growth:.1e}) "
          f"DFA_||h_L||: {dfa_h_L_init:.1e}→{dfa_h_L_final:.1e} (×{dfa_growth:.1e}) "
          f"BP_||g_2||: {bp_g2_init:.1e}→{bp_g2_final:.1e} "
          f"DFA_||g_2||: {dfa_g2_init:.1e}→{dfa_g2_final:.1e}")


def main():
    """Print one summary row per known snapshot file, followed by a legend."""
    print("=" * 130)
    print("SNAPSHOT EVOLUTION COMPARISON: with-out_ln vs no-out_ln vs synthetic")
    print("=" * 130)

    runs = [
        ('with-out_ln s42 (ResMLP CIFAR)', 'results/snapshot_evolution_v2/snapshot_evolution_s42.json'),
        ('no-out_ln s42 (ResMLP CIFAR)', 'results/snapshot_no_outln_v1/snapshot_noLN_s42.json'),
        ('no-out_ln s123 (ResMLP CIFAR)', 'results/snapshot_no_outln_v1/snapshot_noLN_s123.json'),
        ('no-out_ln s456 (ResMLP CIFAR)', 'results/snapshot_no_outln_v1/snapshot_noLN_s456.json'),
        ('synthetic α=1 s42 (StudentNet)', 'results/snapshot_synth_v1/snapshot_synth_a1.0_L4_s42.json'),
    ]
    for label, path in runs:
        # Missing files are reported as MISSING rows rather than aborting the table.
        summary_row(load(path), label)

    print()
    print("Legend: ||h_L|| = median per-sample L2 norm of final hidden state; ||g_2|| = median per-sample L2 norm of BP gradient at h_2.")
    print("All norms use .norm(dim=-1), correct.")
class ResidualBlockPreLN(nn.Module):
    """Pre-LayerNorm MLP block: LayerNorm -> Linear -> GELU -> Linear.

    Returns only the residual branch; the caller adds the skip connection.
    The second linear layer starts near zero (weight std 0.01, zero bias).
    """

    def __init__(self, d_hidden: int):
        super().__init__()
        # Construction order is kept so parameter initialisation consumes the
        # RNG in the same sequence.
        self.ln = nn.LayerNorm(d_hidden)
        self.w1 = nn.Linear(d_hidden, d_hidden)
        self.w2 = nn.Linear(d_hidden, d_hidden)
        nn.init.normal_(self.w2.weight, std=0.01)
        nn.init.zeros_(self.w2.bias)

    def forward(self, h):
        return self.w2(F.gelu(self.w1(self.ln(h))))


class ResidualMLP_NoOutLN(nn.Module):
    """Residual MLP whose head reads the last hidden state directly (no terminal LayerNorm)."""

    def __init__(self, input_dim, d_hidden, num_classes, num_blocks):
        super().__init__()
        self.embed = nn.Linear(input_dim, d_hidden)
        self.blocks = nn.ModuleList([ResidualBlockPreLN(d_hidden) for _ in range(num_blocks)])
        # Deliberately no out_ln between the last block and the head.
        self.out_head = nn.Linear(d_hidden, num_classes)
        self.num_blocks = num_blocks
        self.d_hidden = d_hidden

    def forward(self, x, return_hidden=False):
        """Return logits, or ``(logits, hiddens)`` when ``return_hidden`` is true.

        ``hiddens`` holds the embedding followed by the state after each block,
        i.e. ``num_blocks + 1`` tensors.
        """
        state = self.embed(x)
        trajectory = [state]
        for blk in self.blocks:
            state = state + blk(state)
            trajectory.append(state)
        logits = self.out_head(state)
        if not return_hidden:
            return logits
        return logits, trajectory
[h0.clone().requires_grad_(True)] + for b in model.blocks: + hs.append(hs[-1] + b(hs[-1])) + logits = model.out_head(hs[-1]) # NO out_ln + loss = F.cross_entropy(logits, y_eval) + grads = torch.autograd.grad(loss, hs) + bp_l2 = [g.norm(dim=-1).median().item() for g in grads] + bp_full = [g.detach() for g in grads] + acc = (logits.argmax(-1) == y_eval).float().mean().item() + loss_val = loss.item() + + gamma_dfa = float('nan'); per_layer_gamma = [] + if dfa_Bs is not None: + with torch.no_grad(): + e_T = logits.softmax(-1); e_T[torch.arange(x_eval.size(0)), y_eval] -= 1 + for l in range(L): + a_dfa = (e_T @ dfa_Bs[l].T).detach() + per_layer_gamma.append(cosine_similarity_batch(a_dfa, bp_full[l])) + gamma_dfa = float(np.mean(per_layer_gamma)) + + if was_training: model.train() + return { + 'hidden_norms': hidden_norms, + 'bp_grad_per_sample_l2_med': bp_l2, + 'gamma_dfa': gamma_dfa, + 'gamma_dfa_per_layer': per_layer_gamma, + 'acc_eval': acc, 'loss_eval': loss_val, + } + + +def train_bp(model, train_loader, x_eval, y_eval, device, epochs, lr, wd): + opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd) + sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) + log = [] + d0 = diagnose(model, x_eval, y_eval); d0['epoch'] = 0; log.append(d0) + print(f" [BP-noLN] Ep 0: ||h_L||={d0['hidden_norms'][-1]:.3e} ||g||={d0['bp_grad_per_sample_l2_med'][2]:.3e} acc={d0['acc_eval']:.4f}", flush=True) + for ep in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device); y = y.to(device) + logits = model(x); loss = F.cross_entropy(logits, y) + opt.zero_grad(); loss.backward(); opt.step() + sch.step() + d = diagnose(model, x_eval, y_eval); d['epoch'] = ep; log.append(d) + if ep % 5 == 0 or ep == 1 or ep == epochs: + print(f" [BP-noLN] Ep {ep}: ||h_L||={d['hidden_norms'][-1]:.3e} ||g||={d['bp_grad_per_sample_l2_med'][2]:.3e} acc={d['acc_eval']:.4f}", flush=True) + return log + + +def train_dfa(model, train_loader, x_eval, 
y_eval, device, epochs, lr, wd): + d_hidden = model.d_hidden; L = model.num_blocks; C = 10 + Bs = [torch.randn(d_hidden, C, device=device) / np.sqrt(C) for _ in range(L)] + block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) + head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd) + all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)] + log = [] + d0 = diagnose(model, x_eval, y_eval, dfa_Bs=Bs); d0['epoch'] = 0; log.append(d0) + print(f" [DFA-noLN] Ep 0: ||h_L||={d0['hidden_norms'][-1]:.3e} ||g||={d0['bp_grad_per_sample_l2_med'][2]:.3e} acc={d0['acc_eval']:.4f}", flush=True) + for ep in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device); y = y.to(device) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1 + hL_det = hiddens[-1].detach() + # Head update — NO out_ln + logits_out = model.out_head(hL_det) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad(); loss_out.backward(); head_opt.step() + # Block updates + for l in range(L): + h_l = hiddens[l].detach() + a_dfa = (e_T @ Bs[l].T).detach() + rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a_dfa / rms + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_norm).sum(dim=-1).mean() + block_opts[l].zero_grad(); local_loss.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + # Embed update + a_0 = (e_T @ Bs[0].T).detach() + rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + h0 = model.embed(x) + embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean() + embed_opt.zero_grad(); 
embed_loss.backward(); embed_opt.step() + for s in all_sch: s.step() + d = diagnose(model, x_eval, y_eval, dfa_Bs=Bs); d['epoch'] = ep; log.append(d) + if ep % 5 == 0 or ep == 1 or ep == epochs: + print(f" [DFA-noLN] Ep {ep}: ||h_L||={d['hidden_norms'][-1]:.3e} ||g||={d['bp_grad_per_sample_l2_med'][2]:.3e} acc={d['acc_eval']:.4f} γ={d['gamma_dfa']:.4f}", flush=True) + return log + + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--output_dir', type=str, default='results/snapshot_no_outln_v1') + p.add_argument('--epochs', type=int, default=100) + p.add_argument('--lr', type=float, default=1e-3) + p.add_argument('--wd', type=float, default=0.01) + p.add_argument('--seed', type=int, default=42) + p.add_argument('--depth', type=int, default=4) + p.add_argument('--d_hidden', type=int, default=256) + args = p.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + device = torch.device('cuda:0') + print(f"NO-OUT_LN VARIANT: depth={args.depth}, d_hidden={args.d_hidden}, " + f"epochs={args.epochs}, seed={args.seed}", flush=True) + + train_loader, test_loader = get_cifar10(batch_size=128) + x_eval, y_eval = fixed_eval_buffer(test_loader, device, n_samples=1024) + + L, d, C = args.depth, args.d_hidden, 10 + + print("\n=== BP training (NO out_ln) ===", flush=True) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + bp_model = ResidualMLP_NoOutLN(3072, d, C, L).to(device) + bp_log = train_bp(bp_model, train_loader, x_eval, y_eval, device, args.epochs, args.lr, args.wd) + + print("\n=== DFA training (NO out_ln) ===", flush=True) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + dfa_model = ResidualMLP_NoOutLN(3072, d, C, L).to(device) + dfa_log = train_dfa(dfa_model, train_loader, x_eval, y_eval, device, args.epochs, args.lr, args.wd) + + out = { + 'config': vars(args), 'depth': L, 'd_hidden': d, 'num_classes': C, + 'architecture': 'ResidualMLP_NoOutLN', + 
'bp_log': bp_log, 'dfa_log': dfa_log, + } + out_path = os.path.join(args.output_dir, f'snapshot_noLN_s{args.seed}.json') + with open(out_path, 'w') as f: + json.dump(out, f, indent=2) + print(f"\nSaved {out_path}", flush=True) + + +if __name__ == '__main__': + main() diff --git a/experiments/snapshot_evolution_residual_explosion.py b/experiments/snapshot_evolution_residual_explosion.py index 1dc09f2..6155d94 100644 --- a/experiments/snapshot_evolution_residual_explosion.py +++ b/experiments/snapshot_evolution_residual_explosion.py @@ -212,6 +212,73 @@ def train_dfa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_e return log +def train_fa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_every=1): + """FA (Lillicrap 2016): sequential backward credit with d×d random matrices. + Canonical implementation matching cifar_resmlp.py train_fa(): + - mean reduction (default) + - gradient taken BEFORE head step (old head weights) + - top-down block update, credit propagated after each block + - NO grad clipping + """ + d_hidden = model.d_hidden + L = model.num_blocks + Bs_fa = [torch.randn(d_hidden, d_hidden, device=device) / np.sqrt(d_hidden) for _ in range(L)] + block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) + head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=lr, weight_decay=wd) + all_sch = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]) + log = [] + d0 = diagnose(model, x_eval, y_eval) + d0['epoch'] = 0 + log.append(d0) + print(f" [FA] Ep 0: ||h||_med={d0['hidden_norms']} acc={d0['acc_eval']:.4f}", flush=True) + for epoch in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = 
x.view(x.size(0), -1).to(device) + y = y.to(device) + # Forward + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + # Head update — get gradient BEFORE step (old head weights) + hL_det = hiddens[-1].detach().requires_grad_(True) + logits_out = model.out_head(model.out_ln(hL_det)) + loss_out = F.cross_entropy(logits_out, y) # mean reduction + head_opt.zero_grad() + loss_out.backward() + a_credit = hL_det.grad.detach() # gradient w.r.t. old head + head_opt.step() + # Top-down block updates with sequential FA credit propagation + for l in range(L - 1, -1, -1): + h_l = hiddens[l].detach() + rms = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a_credit / rms + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_norm).sum(dim=-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + block_opts[l].step() # no grad clipping + a_credit = (a_credit @ Bs_fa[l]).detach() + # Embed update with final propagated credit + rms_0 = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + h0 = model.embed(x) + embed_loss = (h0 * (a_credit / rms_0)).sum(dim=-1).mean() + embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step() + for s in all_sch: + s.step() + if epoch % log_every == 0 or epoch == epochs: + d = diagnose(model, x_eval, y_eval) + d['epoch'] = epoch + log.append(d) + print(f" [FA] Ep {epoch}: ||h_L||={d['hidden_norms'][-1]:.3e} " + f"||g_2||={d['bp_grad_norms_per_sample_med'][2]:.3e} " + f"acc={d['acc_eval']:.4f}", flush=True) + return log + + def main(): p = argparse.ArgumentParser() p.add_argument('--output_dir', type=str, default='results/snapshot_evolution_v2') @@ -262,11 +329,22 @@ def main(): args.epochs, args.lr, args.wd, log_every=args.log_every, random_targets=args.random_targets) + fa_log = None + if not args.skip_bp and not args.random_targets: # FA only when doing full run + print("\n=== FA training ===", flush=True) + torch.manual_seed(args.seed); np.random.seed(args.seed); 
torch.cuda.manual_seed_all(args.seed) + fa_model = ResidualMLP(3072, d, C, L, + residual_add=not args.no_residual_add, + w2_std=args.w2_std).to(device) + fa_log = train_fa(fa_model, train_loader, x_eval, y_eval, device, + args.epochs, args.lr, args.wd, log_every=args.log_every) + out = { 'config': vars(args), 'depth': L, 'd_hidden': d, 'num_classes': C, 'bp_log': bp_log, 'dfa_log': dfa_log, + 'fa_log': fa_log, } out_path = os.path.join(args.output_dir, f'snapshot_evolution_s{args.seed}.json') with open(out_path, 'w') as f: diff --git a/experiments/snapshot_evolution_vit.py b/experiments/snapshot_evolution_vit.py new file mode 100644 index 0000000..ce4c090 --- /dev/null +++ b/experiments/snapshot_evolution_vit.py @@ -0,0 +1,244 @@ +""" +Snapshot evolution on a ViT-Mini (modern transformer-style architecture) trained +with BP and block-level DFA on CIFAR-10. Logs ||h_l||, ||BP grad||, Γ per epoch. + +This is the P4 generalization test: does the residual-stream pathology + LayerNorm +gradient collapse mechanism (verified on pre-LN ResMLP with terminal LN) also +appear on an actual transformer architecture? If yes → strong P4 in modern setting. + +Block-level DFA: each TransformerBlock is a "layer". The DFA credit +`a_l = e_T @ B_l^T` is broadcast across all tokens at that block's input. The +local block loss is `<block_l(h_l), broadcast(a_l)>` summed over tokens. 
+ +Usage: + CUDA_VISIBLE_DEVICES=2 nohup python experiments/snapshot_evolution_vit.py \ + --output_dir results/snapshot_vit_v1 --epochs 60 --seed 42 \ + > results/snapshot_vit_v1/run_s42.log 2>&1 & +""" +import os, sys, json, argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.vit_mini import ViTMini, TransformerBlock +from metrics.credit_metrics import cosine_similarity_batch + + +def get_cifar10(batch_size=128): + tv_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) + te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) + return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), + DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2)) + + +def fixed_eval_buffer(test_loader, device, n_samples=1024): + xs, ys = [], [] + for x, y in test_loader: + xs.append(x); ys.append(y) + if sum(xb.size(0) for xb in xs) >= n_samples: + break + return torch.cat(xs)[:n_samples].to(device), torch.cat(ys)[:n_samples].to(device) + + +def diagnose(model, x_eval, y_eval, dfa_Bs=None): + """Compute per-block ||h_l|| and ||BP grad at h_l||, plus optional Γ vs DFA credit.""" + was_training = model.training + model.eval() + L = model.num_blocks + + # Hidden states (no grad) + with torch.no_grad(): + _, hiddens = model(x_eval, return_hidden=True) + # hiddens[l] is shape (B, 
n_tokens, d_model) + # Reduce to per-sample by taking the cls-token norm OR by flattening across tokens + # We'll report cls-token norm (the one that actually flows to the head) + hidden_norms_cls = [h[:, 0].norm(dim=-1).median().item() for h in hiddens] + hidden_norms_avg = [h.norm(dim=-1).mean().item() for h in hiddens] # avg across tokens then over batch + + # BP gradients + h0 = model.embed(x_eval.detach()) + hs = [h0.clone().requires_grad_(True)] + for b in model.blocks: + hs.append(b(hs[-1])) + h_cls = model.out_ln(hs[-1][:, 0]) + logits = model.out_head(h_cls) + loss = F.cross_entropy(logits, y_eval) + grads = torch.autograd.grad(loss, hs) + # grads[l] is shape (B, n_tokens, d_model) + # Per-sample L2 norm: take Frobenius over tokens × d_model + bp_grad_per_sample_l2 = [g.flatten(1).norm(dim=-1).median().item() for g in grads] + bp_grad_F = [g.norm().item() for g in grads] + bp_full = [g.detach() for g in grads] + + acc = (logits.argmax(-1) == y_eval).float().mean().item() + loss_val = loss.item() + + gamma_dfa = float('nan'); per_layer_gamma = [] + if dfa_Bs is not None: + with torch.no_grad(): + e_T = logits.softmax(-1); e_T[torch.arange(x_eval.size(0)), y_eval] -= 1 + for l in range(L): + # Block-level DFA credit: per-sample (B, d_model), broadcast to (B, n_tokens, d_model) + a_dfa_per_sample = (e_T @ dfa_Bs[l].T).detach() # (B, d_model) + a_dfa_broadcast = a_dfa_per_sample.unsqueeze(1).expand_as(bp_full[l]) # (B, n_tokens, d_model) + # Cosine using flattened (per-sample) representation + per_layer_gamma.append(cosine_similarity_batch( + a_dfa_broadcast.flatten(1), bp_full[l].flatten(1))) + gamma_dfa = float(np.mean(per_layer_gamma)) + + if was_training: + model.train() + + return { + 'hidden_norms_cls': hidden_norms_cls, + 'hidden_norms_avg': hidden_norms_avg, + 'bp_grad_per_sample_l2_med': bp_grad_per_sample_l2, + 'bp_grad_F': bp_grad_F, + 'gamma_dfa': gamma_dfa, + 'gamma_dfa_per_layer': per_layer_gamma, + 'acc_eval': acc, + 'loss_eval': loss_val, + } + 
+ +def train_bp(model, train_loader, x_eval, y_eval, device, epochs, lr, wd): + opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd) + sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) + log = [] + d0 = diagnose(model, x_eval, y_eval); d0['epoch'] = 0; log.append(d0) + print(f" [BP-vit] Ep 0: ||h_L_cls||={d0['hidden_norms_cls'][-1]:.3e} ||g_2||={d0['bp_grad_per_sample_l2_med'][2]:.3e} acc={d0['acc_eval']:.4f}", flush=True) + for ep in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.to(device); y = y.to(device) + logits = model(x); loss = F.cross_entropy(logits, y) + opt.zero_grad(); loss.backward(); opt.step() + sch.step() + d = diagnose(model, x_eval, y_eval); d['epoch'] = ep; log.append(d) + if ep % 5 == 0 or ep == 1 or ep == epochs: + print(f" [BP-vit] Ep {ep}: ||h_L_cls||={d['hidden_norms_cls'][-1]:.3e} ||g_2||={d['bp_grad_per_sample_l2_med'][2]:.3e} acc={d['acc_eval']:.4f}", flush=True) + return log + + +def train_dfa_block_level(model, train_loader, x_eval, y_eval, device, epochs, lr, wd): + """Block-level DFA on ViT. Each TransformerBlock is treated as a unit; DFA credit + is broadcast across all tokens at the block's input. 
+ """ + d_model = model.d_hidden + L = model.num_blocks + C = 10 + Bs = [torch.randn(d_model, C, device=device) / np.sqrt(C) for _ in range(L)] + + block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] + embed_opt = optim.AdamW( + list(model.patch_embed.parameters()) + [model.cls_token, model.pos_embed], + lr=lr, weight_decay=wd) + head_opt = optim.AdamW( + list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=lr, weight_decay=wd) + all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)] + log = [] + d0 = diagnose(model, x_eval, y_eval, dfa_Bs=Bs); d0['epoch'] = 0; log.append(d0) + print(f" [DFA-vit] Ep 0: ||h_L_cls||={d0['hidden_norms_cls'][-1]:.3e} ||g_2||={d0['bp_grad_per_sample_l2_med'][2]:.3e} acc={d0['acc_eval']:.4f}", flush=True) + for ep in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.to(device); y = y.to(device) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1 + hL_det = hiddens[-1].detach() + # Head update via direct CE on cls token + h_cls = model.out_ln(hL_det[:, 0]) + logits_out = model.out_head(h_cls) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad(); loss_out.backward(); head_opt.step() + # Block updates: each block's local loss = <block(h_l), a_dfa_broadcast> + for l in range(L): + h_l = hiddens[l].detach() # (B, n_tokens, d) + a_dfa = (e_T @ Bs[l].T).detach() # (B, d) + a_dfa_broadcast = a_dfa.unsqueeze(1).expand_as(h_l) # (B, n_tokens, d) + rms = (a_dfa_broadcast ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a_dfa_broadcast / rms + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_norm).sum(dim=-1).mean() + block_opts[l].zero_grad(); local_loss.backward() + 
torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + # Embed update (patch embed + cls + pos) + a_0 = (e_T @ Bs[0].T).detach() + rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + h0 = model.embed(x) # (B, n_tokens, d) + a_0_broadcast = a_0.unsqueeze(1).expand_as(h0) + embed_loss = (h0 * (a_0_broadcast / rms_0.unsqueeze(1))).sum(dim=-1).mean() + embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step() + for s in all_sch: s.step() + d = diagnose(model, x_eval, y_eval, dfa_Bs=Bs); d['epoch'] = ep; log.append(d) + if ep % 5 == 0 or ep == 1 or ep == epochs: + print(f" [DFA-vit] Ep {ep}: ||h_L_cls||={d['hidden_norms_cls'][-1]:.3e} ||g_2||={d['bp_grad_per_sample_l2_med'][2]:.3e} acc={d['acc_eval']:.4f} γ={d['gamma_dfa']:.4f}", flush=True) + return log + + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--output_dir', type=str, default='results/snapshot_vit_v1') + p.add_argument('--epochs', type=int, default=60) + p.add_argument('--lr', type=float, default=1e-3) + p.add_argument('--wd', type=float, default=0.05) + p.add_argument('--seed', type=int, default=42) + p.add_argument('--depth', type=int, default=4) + p.add_argument('--d_model', type=int, default=128) + p.add_argument('--n_heads', type=int, default=4) + args = p.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + device = torch.device('cuda:0') + print(f"ViT-MINI: depth={args.depth}, d_model={args.d_model}, n_heads={args.n_heads}, " + f"epochs={args.epochs}, seed={args.seed}", flush=True) + + train_loader, test_loader = get_cifar10(batch_size=128) + x_eval, y_eval = fixed_eval_buffer(test_loader, device, n_samples=1024) + + print("\n=== BP training (ViT-Mini) ===", flush=True) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + bp_model = ViTMini(num_blocks=args.depth, d_model=args.d_model, n_heads=args.n_heads).to(device) + print(f" n_params={sum(p.numel() for p in 
bp_model.parameters())}", flush=True) + bp_log = train_bp(bp_model, train_loader, x_eval, y_eval, device, args.epochs, args.lr, args.wd) + + print("\n=== DFA training (ViT-Mini, block-level DFA) ===", flush=True) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + dfa_model = ViTMini(num_blocks=args.depth, d_model=args.d_model, n_heads=args.n_heads).to(device) + dfa_log = train_dfa_block_level(dfa_model, train_loader, x_eval, y_eval, device, args.epochs, args.lr, args.wd) + + out = { + 'config': vars(args), 'depth': args.depth, 'd_model': args.d_model, + 'architecture': 'ViTMini', 'bp_log': bp_log, 'dfa_log': dfa_log, + } + out_path = os.path.join(args.output_dir, f'snapshot_vit_s{args.seed}.json') + with open(out_path, 'w') as f: + json.dump(out, f, indent=2) + print(f"\nSaved {out_path}", flush=True) + + +if __name__ == '__main__': + main() diff --git a/experiments/snapshot_fa_crossarch.py b/experiments/snapshot_fa_crossarch.py new file mode 100644 index 0000000..8fa9e71 --- /dev/null +++ b/experiments/snapshot_fa_crossarch.py @@ -0,0 +1,243 @@ +""" +FA-only snapshot evolution for ViT-Mini and ResMLP-no-outLN. +Produces per-epoch ||h_L||, ||g_L||, acc for FA training. 
+""" +import os, sys, json, argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision, torchvision.transforms as transforms + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from models.residual_mlp import ResidualMLP +from models.vit_mini import ViTMini + + +def get_cifar10(batch_size=128): + tv_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) + te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) + return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), + DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2)) + + +def fixed_eval_buffer(loader, device, n=1024): + xs, ys = [], [] + for x, y in loader: + xs.append(x); ys.append(y) + if sum(xb.size(0) for xb in xs) >= n: + break + return torch.cat(xs)[:n].to(device), torch.cat(ys)[:n].to(device) + + +# ─── Diagnose (works for both ViT and ResMLP) ─────────────────────────── + +def diagnose_resmlp(model, x_eval, y_eval): + model.eval() + x_flat = x_eval.view(x_eval.size(0), -1) + with torch.no_grad(): + _, hiddens = model(x_flat, return_hidden=True) + hidden_norms = [h.norm(dim=-1).median().item() for h in hiddens] + # BP grads + h0 = model.embed(x_flat.detach()) + hs = [h0.clone().requires_grad_(True)] + for b in model.blocks: + hs.append(hs[-1] + b(hs[-1])) + # Handle both with and without out_ln + if hasattr(model, 'out_ln'): + logits = model.out_head(model.out_ln(hs[-1])) + else: + logits = model.out_head(hs[-1]) + 
loss = F.cross_entropy(logits, y_eval) + grads = torch.autograd.grad(loss, hs) + g_norms = [g.norm(dim=-1).median().item() for g in grads] + acc = (logits.argmax(-1) == y_eval).float().mean().item() + model.train() + return {'hidden_norms': hidden_norms, 'bp_grad_norms_per_sample_med': g_norms, 'acc_eval': acc} + + +def diagnose_vit(model, x_eval, y_eval): + model.eval() + with torch.no_grad(): + _, hiddens = model(x_eval, return_hidden=True) + h_cls_norms = [h[:, 0].norm(dim=-1).median().item() for h in hiddens] + # BP grads via manual forward + h0 = model.embed(x_eval.detach()) + hs = [h0.clone().requires_grad_(True)] + for b in model.blocks: + hs.append(hs[-1] + b(hs[-1])) + h_cls = model.out_ln(hs[-1][:, 0]) + logits = model.out_head(h_cls) + loss = F.cross_entropy(logits, y_eval) + grads = torch.autograd.grad(loss, hs) + g_cls_norms = [g[:, 0].norm(dim=-1).median().item() for g in grads] + acc = (logits.argmax(-1) == y_eval).float().mean().item() + model.train() + return {'hidden_norms_cls': h_cls_norms, 'bp_grad_per_sample_l2_med': g_cls_norms, 'acc_eval': acc} + + +# ─── FA training ───────────────────────────────────────────────────────── + +def train_fa_resmlp(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, no_outln=False): + d_hidden = model.d_hidden + L = model.num_blocks + Bs = [torch.randn(d_hidden, d_hidden, device=device) / np.sqrt(d_hidden) for _ in range(L)] + block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) + head_params = list(model.out_head.parameters()) + if hasattr(model, 'out_ln') and model.out_ln is not None: + head_params += list(model.out_ln.parameters()) + head_opt = optim.AdamW(head_params, lr=lr, weight_decay=wd) + all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), + 
optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)] + log = [] + d0 = diagnose_resmlp(model, x_eval, y_eval); d0['epoch'] = 0; log.append(d0) + print(f" [FA] Ep 0: acc={d0['acc_eval']:.4f}", flush=True) + for ep in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device); y = y.to(device) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + hL_det = hiddens[-1].detach() + logits_out = model.out_head(model.out_ln(hL_det)) if hasattr(model, 'out_ln') else model.out_head(hL_det) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad(); loss_out.backward(); head_opt.step() + # FA credits + hL_req = hiddens[-1].detach().requires_grad_(True) + logits_fa = model.out_head(model.out_ln(hL_req)) if hasattr(model, 'out_ln') else model.out_head(hL_req) + loss_fa = F.cross_entropy(logits_fa, y, reduction='sum') + a_L = torch.autograd.grad(loss_fa, hL_req)[0].detach() + credits = [None] * L + credits[L-1] = a_L + for ll in range(L-2, -1, -1): + credits[ll] = (credits[ll+1] @ Bs[ll+1]).detach() + for l in range(L): + h_l = hiddens[l].detach() + a_l = credits[l] + rms = (a_l**2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + f_l = model.blocks[l](h_l) + local_loss = (f_l * (a_l / rms)).sum(dim=-1).mean() + block_opts[l].zero_grad(); local_loss.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + a_0 = credits[0] + rms_0 = (a_0**2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + h0 = model.embed(x) + embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean() + embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step() + for s in all_sch: s.step() + d = diagnose_resmlp(model, x_eval, y_eval); d['epoch'] = ep; log.append(d) + if ep % 10 == 0 or ep == 1 or ep == epochs: + print(f" [FA] Ep {ep}: ||h_L||={d['hidden_norms'][-1]:.3e} " + f"||g_L||={d['bp_grad_norms_per_sample_med'][-1]:.3e} " + f"acc={d['acc_eval']:.4f}", flush=True) + return log + + +def 
train_fa_vit(model, train_loader, x_eval, y_eval, device, epochs, lr, wd): + """Canonical FA for ViT: mean reduction, grad before step, no clipping, top-down.""" + d_model = model.d_hidden + L = model.num_blocks + Bs = [torch.randn(d_model, d_model, device=device) / np.sqrt(d_model) for _ in range(L)] + block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] + embed_opt = optim.AdamW( + list(model.patch_embed.parameters()) + [model.cls_token, model.pos_embed], + lr=lr, weight_decay=wd) + head_opt = optim.AdamW( + list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=lr, weight_decay=wd) + all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)] + log = [] + d0 = diagnose_vit(model, x_eval, y_eval); d0['epoch'] = 0; log.append(d0) + print(f" [FA-vit] Ep 0: acc={d0['acc_eval']:.4f}", flush=True) + for ep in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.to(device); y = y.to(device) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + # Head update — grad BEFORE step (old head) + hL_det = hiddens[-1].detach().requires_grad_(True) + h_cls = model.out_ln(hL_det[:, 0]) + logits_out = model.out_head(h_cls) + loss_out = F.cross_entropy(logits_out, y) # mean reduction + head_opt.zero_grad() + loss_out.backward() + a_L_full = hL_det.grad.detach() # (B, n_tokens, d) + head_opt.step() + # Use mean over tokens for the backward signal + a_credit = a_L_full.mean(dim=1) # (B, d) + # Top-down block updates, propagate credit after each + for l in range(L - 1, -1, -1): + h_l = hiddens[l].detach() + a_broadcast = a_credit.unsqueeze(1).expand_as(h_l) + rms = (a_broadcast ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + f_l = model.blocks[l](h_l) + local_loss = (f_l * (a_broadcast / rms)).sum(dim=-1).mean() + 
block_opts[l].zero_grad() + local_loss.backward() + block_opts[l].step() # no clipping + a_credit = (a_credit @ Bs[l]).detach() + # Embed update with final propagated credit + a_0_broadcast = a_credit.unsqueeze(1) + rms_0 = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + h0 = model.embed(x) + embed_loss = (h0 * (a_0_broadcast / rms_0.unsqueeze(1))).sum(dim=-1).mean() + embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step() + for s in all_sch: s.step() + d = diagnose_vit(model, x_eval, y_eval); d['epoch'] = ep; log.append(d) + if ep % 5 == 0 or ep == 1 or ep == epochs: + print(f" [FA-vit] Ep {ep}: ||h_L||={d['hidden_norms_cls'][-1]:.3e} " + f"||g_L||={d['bp_grad_per_sample_l2_med'][-1]:.3e} " + f"acc={d['acc_eval']:.4f}", flush=True) + return log + + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--arch', choices=['vit', 'resmlp_noln'], required=True) + p.add_argument('--output', type=str, required=True) + p.add_argument('--epochs', type=int, default=100) + p.add_argument('--seed', type=int, default=42) + args = p.parse_args() + + device = torch.device('cuda:0') + train_loader, test_loader = get_cifar10(128) + x_eval, y_eval = fixed_eval_buffer(test_loader, device, 1024) + + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + if args.arch == 'vit': + # Match ViT snapshot params + model = ViTMini(d_model=128, n_heads=4, num_blocks=4, num_classes=10).to(device) + fa_log = train_fa_vit(model, train_loader, x_eval, y_eval, device, + args.epochs, lr=1e-3, wd=0.05) + else: + # ResMLP without terminal LN — use the same class as the original no-outln experiment + from experiments.snapshot_evolution_no_outln import ResidualMLP_NoOutLN + model = ResidualMLP_NoOutLN(3072, 256, 10, 4).to(device) + fa_log = train_fa_resmlp(model, train_loader, x_eval, y_eval, device, + args.epochs, lr=1e-3, wd=0.01, no_outln=True) + + with open(args.output, 'w') as f: + json.dump({'fa_log': fa_log, 'arch': 
args.arch, 'seed': args.seed}, f, indent=2) + print(f"Saved: {args.output}", flush=True) + + +if __name__ == '__main__': + main() diff --git a/experiments/snapshot_fa_only.py b/experiments/snapshot_fa_only.py new file mode 100644 index 0000000..cdc69ae --- /dev/null +++ b/experiments/snapshot_fa_only.py @@ -0,0 +1,38 @@ +"""Quick FA-only snapshot evolution. Reuses the full script's train_fa + diagnose.""" +import os, sys, json, argparse +import numpy as np +import torch +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from experiments.snapshot_evolution_residual_explosion import ( + get_cifar10, fixed_eval_buffer, train_fa +) +from models.residual_mlp import ResidualMLP + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--output', type=str, required=True) + p.add_argument('--epochs', type=int, default=100) + p.add_argument('--seed', type=int, default=42) + p.add_argument('--depth', type=int, default=4) + p.add_argument('--d_hidden', type=int, default=256) + args = p.parse_args() + + device = torch.device('cuda:0') + train_loader, test_loader = get_cifar10(batch_size=128) + x_eval, y_eval = fixed_eval_buffer(test_loader, device, n_samples=1024) + + L, d, C = args.depth, args.d_hidden, 10 + print(f"FA snapshot: depth={L}, d={d}, seed={args.seed}, epochs={args.epochs}", flush=True) + + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + model = ResidualMLP(3072, d, C, L).to(device) + fa_log = train_fa(model, train_loader, x_eval, y_eval, device, + args.epochs, 1e-3, 0.01, log_every=1) + + with open(args.output, 'w') as f: + json.dump({'fa_log': fa_log, 'seed': args.seed, 'depth': L, 'd_hidden': d}, f, indent=2) + print(f"Saved: {args.output}", flush=True) + +if __name__ == '__main__': + main() diff --git a/experiments/snapshot_fa_studentnet.py b/experiments/snapshot_fa_studentnet.py new file mode 100644 index 0000000..887365c --- /dev/null +++ 
def _probe_layer_value(values, layer=2):
    """Return ``values[layer]``, clamped to the last entry for shallow models."""
    return values[min(layer, len(values) - 1)]


def train_fa_synth(model, train_loader, x_eval, y_eval, device, epochs, lr, wd):
    """Canonical FA for StudentNet: mean reduction, grad before step, no clipping.

    The head is trained with true cross-entropy on the detached last hidden
    state; the gradient of that loss w.r.t. the last hidden state (taken
    before the head is stepped) seeds the credit signal.  Blocks are then
    updated top-down with a local linear loss against the RMS-normalised
    credit, and the credit is pushed one block down through a fixed random
    matrix ``Bs[l]`` after each block update.

    Args:
        model: network exposing ``d_hidden``, ``num_blocks``, ``blocks``,
            ``out_head`` and ``model(x, return_hidden=True)``.
        train_loader: iterable of ``(x, y)`` batches.
        x_eval, y_eval: fixed evaluation buffer handed to ``diagnose_synth``.
        device: device on which the feedback matrices are created.
        epochs: number of epochs (also the cosine schedule length).
        lr, wd: AdamW learning rate and weight decay for every optimizer.

    Returns:
        List of per-epoch diagnostic dicts (epoch 0 = before training), each
        with an added ``'epoch'`` key.
    """
    d_hidden = model.d_hidden
    L = model.num_blocks
    # Fixed random feedback matrices, one per block; never trained.
    Bs = [torch.randn(d_hidden, d_hidden, device=device) / np.sqrt(d_hidden) for _ in range(L)]
    block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
    head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
    all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs)
               for o in block_opts + [head_opt]]

    log = []
    d0 = diagnose_synth(model, x_eval, y_eval)
    d0['epoch'] = 0
    log.append(d0)
    print(f" [FA] Ep 0: acc={d0['acc_eval']:.4f}", flush=True)

    for ep in range(1, epochs + 1):
        model.train()
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)
            with torch.no_grad():
                _, hiddens = model(x, return_hidden=True)

            # Head update — grad BEFORE step (old head)
            hL_det = hiddens[-1].detach().requires_grad_(True)
            logits_out = model.out_head(hL_det)
            loss_out = F.cross_entropy(logits_out, y)  # mean reduction
            head_opt.zero_grad()
            loss_out.backward()
            a_credit = hL_det.grad.detach()
            head_opt.step()

            # Top-down block updates, propagate credit after each
            for l in reversed(range(L)):
                h_l = hiddens[l].detach()
                rms = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
                f_l = model.blocks[l](h_l)
                local_loss = (f_l * (a_credit / rms)).sum(dim=-1).mean()
                block_opts[l].zero_grad()
                local_loss.backward()
                block_opts[l].step()  # no clipping
                a_credit = (a_credit @ Bs[l]).detach()
            # No embed for StudentNet (input is already d_hidden)

        for s in all_sch:
            s.step()
        d = diagnose_synth(model, x_eval, y_eval)
        d['epoch'] = ep
        log.append(d)
        if ep % 5 == 0 or ep in (1, epochs):
            # The probe index is clamped: the diagnostic list has one entry
            # per hidden state, so a fixed [2] would fail for --depth 1.
            g_probe = _probe_layer_value(d['bp_grad_per_sample_l2_med'])
            print(f" [FA] Ep {ep}: ||h_L||={d['hidden_norms'][-1]:.3e} "
                  f"||g||={g_probe:.3e} "
                  f"acc={d['acc_eval']:.4f}", flush=True)
    return log


def main():
    """CLI entry point: train a StudentNet with FA on synthetic teacher data.

    Builds a ``TeacherNet`` (teacher seed fixed to 0), samples 50*256 training
    and 2000 evaluation points with ``generate_synth_dataset``, trains a
    ``StudentNet`` via :func:`train_fa_synth` (lr=1e-3, wd=0.01) and writes
    the log plus the run configuration to ``--output`` as JSON.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--output', type=str, required=True)
    p.add_argument('--epochs', type=int, default=80)
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--alpha', type=float, default=1.0)
    p.add_argument('--depth', type=int, default=4)
    p.add_argument('--d_hidden', type=int, default=128)
    args = p.parse_args()

    # Make sure the result can be written before spending time on training.
    os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)

    # Unchanged on CUDA machines; CPU fallback avoids a crash elsewhere.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    L, d, C = args.depth, args.d_hidden, 10
    set_seed(args.seed)
    teacher = TeacherNet(d, L, C, args.alpha, seed=0).to(device)
    X_tr, Y_tr = generate_synth_dataset(teacher, 50 * 256, d, device, seed=args.seed)
    X_te, Y_te = generate_synth_dataset(teacher, 2000, d, device, seed=args.seed + 10000)
    train_loader = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=256, shuffle=True)

    print(f"StudentNet FA: alpha={args.alpha}, L={L}, d={d}, seed={args.seed}", flush=True)
    # Re-seed so the student initialisation does not depend on how much
    # randomness dataset generation consumed.
    set_seed(args.seed)
    model = StudentNet(d, C, L, args.alpha).to(device)
    fa_log = train_fa_synth(model, train_loader, X_te.to(device), Y_te.to(device),
                            device, args.epochs, 1e-3, 0.01)

    with open(args.output, 'w') as f:
        json.dump({'fa_log': fa_log, 'seed': args.seed, 'alpha': args.alpha,
                   'depth': L, 'd_hidden': d}, f, indent=2)
    print(f"Saved: {args.output}", flush=True)


if __name__ == '__main__':
    main()
def _probe_layer_value(values, layer=2):
    """Return ``values[layer]``, clamped to the last entry for shallow models."""
    return values[min(layer, len(values) - 1)]


def diagnose_synth(model, x_eval, y_eval, dfa_Bs=None):
    """Measure hidden-state norms, BP gradient norms and accuracy on a fixed buffer.

    The model is switched to eval mode for the measurement and put back into
    train mode afterwards if it was training on entry (also when the
    measurement raises).

    Args:
        model: network exposing ``num_blocks``, ``blocks``, ``out_head`` and
            ``model(x, return_hidden=True) -> (logits, hiddens)``.
        x_eval, y_eval: evaluation inputs and integer class labels.
        dfa_Bs: optional list of per-block DFA feedback matrices; when given,
            the cosine alignment between the DFA credit and the BP gradient
            is computed per block via ``cosine_similarity_batch``.

    Returns:
        dict with keys ``hidden_norms``, ``bp_grad_per_sample_l2_med``,
        ``bp_grad_F``, ``gamma_dfa`` (NaN without ``dfa_Bs``),
        ``gamma_dfa_per_layer``, ``acc_eval`` and ``loss_eval``.  The two
        gradient lists have one entry per element of the re-built hidden
        chain, i.e. ``num_blocks + 1`` entries.
    """
    was_training = model.training
    model.eval()
    try:
        L = model.num_blocks

        with torch.no_grad():
            _, hi = model(x_eval, return_hidden=True)
            hidden_norms = [h.norm(dim=-1).median().item() for h in hi]

        # BP grads: rebuild the residual chain by hand so every hidden state
        # is a graph node we can differentiate against.
        # NOTE(review): this assumes the model's own forward is exactly
        # h + block(h) followed by out_head — confirm against StudentNet.
        h_list = [x_eval.detach().requires_grad_(True)]
        for block in model.blocks:
            h_list.append(h_list[-1] + block(h_list[-1]))
        logits = model.out_head(h_list[-1])
        loss = F.cross_entropy(logits, y_eval)
        grads = torch.autograd.grad(loss, h_list)
        bp_grad_l2 = [g.norm(dim=-1).median().item() for g in grads]
        bp_grad_F = [g.norm().item() for g in grads]
        bp_full = [g.detach() for g in grads]
        acc = (logits.argmax(-1) == y_eval).float().mean().item()
        loss_val = loss.item()

        gamma_dfa = float('nan')
        per_layer_gamma = []
        if dfa_Bs is not None:
            with torch.no_grad():
                # Output error of softmax cross-entropy: p - onehot(y).
                e_T = logits.softmax(dim=-1)
                e_T[torch.arange(x_eval.size(0)), y_eval] -= 1.0
                for l in range(L):
                    a_dfa = (e_T @ dfa_Bs[l].T).detach()
                    # NOTE(review): the credit for block l is compared with
                    # the gradient at the block's *input* (bp_full[l]); the
                    # DFA trainer applies it to the block's output — confirm
                    # this alignment is the intended one.
                    per_layer_gamma.append(cosine_similarity_batch(a_dfa, bp_full[l]))
            gamma_dfa = float(np.mean(per_layer_gamma))
    finally:
        if was_training:
            model.train()

    return {
        'hidden_norms': hidden_norms,
        'bp_grad_per_sample_l2_med': bp_grad_l2,
        'bp_grad_F': bp_grad_F,
        'gamma_dfa': gamma_dfa,
        'gamma_dfa_per_layer': per_layer_gamma,
        'acc_eval': acc,
        'loss_eval': loss_val,
    }


def train_bp(model, train_loader, x_eval, y_eval, device, epochs, lr, wd):
    """Train ``model`` end-to-end with backprop, logging diagnostics every epoch.

    Uses a single AdamW optimizer with a cosine schedule over ``epochs``.

    Returns:
        List of diagnostic dicts from :func:`diagnose_synth` (epoch 0 is the
        untrained model), each with an added ``'epoch'`` key.
    """
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    log = []

    d0 = diagnose_synth(model, x_eval, y_eval)
    d0['epoch'] = 0
    log.append(d0)
    # Probe index is clamped so shallow models (depth < 2) do not crash here.
    print(f" [BP] Ep 0: ||h_L||={d0['hidden_norms'][-1]:.3e} "
          f"||g||={_probe_layer_value(d0['bp_grad_per_sample_l2_med']):.3e} "
          f"acc={d0['acc_eval']:.4f}", flush=True)

    for ep in range(1, epochs + 1):
        model.train()
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        sch.step()

        d = diagnose_synth(model, x_eval, y_eval)
        d['epoch'] = ep
        log.append(d)
        if ep % 5 == 0 or ep in (1, epochs):
            print(f" [BP] Ep {ep}: ||h_L||={d['hidden_norms'][-1]:.3e} "
                  f"||g||={_probe_layer_value(d['bp_grad_per_sample_l2_med']):.3e} "
                  f"acc={d['acc_eval']:.4f}", flush=True)
    return log


def train_dfa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd):
    """Train ``model`` with Direct Feedback Alignment and log diagnostics.

    The head learns from true cross-entropy on the detached last hidden
    state.  Each block ``l`` learns from a local linear loss against the
    RMS-normalised credit ``e_T @ Bs[l].T`` (output error projected through
    a fixed random matrix), with gradient-norm clipping at 1.0.

    Returns:
        List of diagnostic dicts from :func:`diagnose_synth` (including the
        DFA/BP alignment ``gamma_dfa``), each with an added ``'epoch'`` key.
    """
    d_hidden = model.d_hidden
    L = model.num_blocks
    # Number of classes is fixed here; it must match the head's output width.
    C = 10
    Bs = [torch.randn(d_hidden, C, device=device) / np.sqrt(C) for _ in range(L)]
    block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
    head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
    all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs)
               for o in block_opts + [head_opt]]
    log = []

    d0 = diagnose_synth(model, x_eval, y_eval, dfa_Bs=Bs)
    d0['epoch'] = 0
    log.append(d0)
    print(f" [DFA] Ep 0: ||h_L||={d0['hidden_norms'][-1]:.3e} "
          f"||g||={_probe_layer_value(d0['bp_grad_per_sample_l2_med']):.3e} "
          f"acc={d0['acc_eval']:.4f}", flush=True)

    for ep in range(1, epochs + 1):
        model.train()
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)
            batch = x.size(0)
            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                e_T = logits.softmax(dim=-1)
                e_T[torch.arange(batch), y] -= 1
            hL_det = hiddens[-1].detach()

            # head update via direct CE on head(hL)
            logits_out = model.out_head(hL_det)
            loss_out = F.cross_entropy(logits_out, y)
            head_opt.zero_grad()
            loss_out.backward()
            head_opt.step()

            # block updates via DFA local credit
            for l in range(L):
                h_l = hiddens[l].detach()
                a_dfa = (e_T @ Bs[l].T).detach()
                rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
                a_norm = a_dfa / rms
                f_l = model.blocks[l](h_l)
                local_loss = (f_l * a_norm).sum(dim=-1).mean()
                block_opts[l].zero_grad()
                local_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
                block_opts[l].step()

        for s in all_sch:
            s.step()
        d = diagnose_synth(model, x_eval, y_eval, dfa_Bs=Bs)
        d['epoch'] = ep
        log.append(d)
        if ep % 5 == 0 or ep in (1, epochs):
            print(f" [DFA] Ep {ep}: ||h_L||={d['hidden_norms'][-1]:.3e} "
                  f"||g||={_probe_layer_value(d['bp_grad_per_sample_l2_med']):.3e} "
                  f"acc={d['acc_eval']:.4f} γ_dfa={d['gamma_dfa']:.4f}", flush=True)
    return log


def main():
    """CLI entry point: train BP and DFA students on the same synthetic task.

    Both students are built after re-seeding with ``--seed`` so they start
    from the same initialisation; the two logs and the configuration are
    written to ``<output_dir>/snapshot_synth_a<alpha>_L<depth>_s<seed>.json``.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--output_dir', type=str, default='results/snapshot_synth_v1')
    p.add_argument('--epochs', type=int, default=80)
    p.add_argument('--alpha', type=float, default=1.0)
    p.add_argument('--depth', type=int, default=4)
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--d_hidden', type=int, default=128)
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--wd', type=float, default=0.01)
    args = p.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    # Unchanged on CUDA machines; CPU fallback avoids a crash elsewhere.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f"device={device}, alpha={args.alpha}, depth={args.depth}, "
          f"d_hidden={args.d_hidden}, epochs={args.epochs}, seed={args.seed}", flush=True)

    set_seed(args.seed)
    L, d, C = args.depth, args.d_hidden, 10
    teacher = TeacherNet(d, L, C, args.alpha, seed=0).to(device)

    n_train = 50 * 256
    n_test = 2000
    X_tr, Y_tr = generate_synth_dataset(teacher, n_train, d, device, seed=args.seed)
    X_te, Y_te = generate_synth_dataset(teacher, n_test, d, device, seed=args.seed + 10000)
    train_loader = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=256, shuffle=True)
    x_eval, y_eval = X_te.to(device), Y_te.to(device)
    print(f"train: {X_tr.shape}, test eval buffer: {x_eval.shape}", flush=True)

    print("\n=== BP training ===", flush=True)
    set_seed(args.seed)
    bp_model = StudentNet(d, C, L, args.alpha).to(device)
    bp_log = train_bp(bp_model, train_loader, x_eval, y_eval, device,
                      args.epochs, args.lr, args.wd)

    print("\n=== DFA training ===", flush=True)
    set_seed(args.seed)
    dfa_model = StudentNet(d, C, L, args.alpha).to(device)
    dfa_log = train_dfa(dfa_model, train_loader, x_eval, y_eval, device,
                        args.epochs, args.lr, args.wd)

    out = {
        'config': vars(args),
        'depth': L, 'd_hidden': d, 'num_classes': C,
        'bp_log': bp_log,
        'dfa_log': dfa_log,
    }
    out_path = os.path.join(args.output_dir,
                            f'snapshot_synth_a{args.alpha}_L{L}_s{args.seed}.json')
    with open(out_path, 'w') as f:
        json.dump(out, f, indent=2)
    print(f"\nSaved {out_path}", flush=True)


if __name__ == '__main__':
    main()
def get_loaders(batch_size=128):
    """Build CIFAR-10 train/test DataLoaders (downloads to ``./data`` if needed).

    The training set uses random crop (padding 4) and horizontal flip; both
    splits are normalised with the same per-channel mean/std.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
    tv_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
    te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
    return (
        DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
        DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2),
    )


def evaluate(model, loader, dev):
    """Return top-1 accuracy of ``model`` over every batch in ``loader``.

    Side effect: the model is left in eval mode; callers that keep training
    must switch it back themselves.

    Raises:
        ValueError: if ``loader`` yields no samples (accuracy is undefined;
            previously this surfaced as a bare ZeroDivisionError).
    """
    model.eval()
    n_seen = 0
    n_correct = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(dev), y.to(dev)
            preds = model(x).argmax(-1)
            n_correct += (preds == y).sum().item()
            n_seen += x.size(0)
    if n_seen == 0:
        raise ValueError("evaluate() received an empty loader; accuracy is undefined")
    return n_correct / n_seen


def freeze_blocks(model):
    """Disable gradients for ``model.blocks`` and put that submodule in eval mode."""
    for p in model.blocks.parameters():
        p.requires_grad_(False)
    model.blocks.eval()


def train_bp_frozen(train_loader, test_loader, dev, epochs=30, seed=42, lr=1e-3, wd=0.05):
    """Backprop baseline with the 4 ViT-Mini blocks frozen at random init.

    Only parameters that still require gradients after :func:`freeze_blocks`
    are handed to AdamW.  Test accuracy is printed at epoch 1, every 5th
    epoch and the final epoch.

    Returns:
        The trained model.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    m = ViTMini(num_blocks=4, d_model=128, n_heads=4).to(dev)
    freeze_blocks(m)
    n_trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in m.parameters())
    print(f"BP-frozen-blocks: {n_trainable}/{n_total} params trainable", flush=True)

    trainable = [p for p in m.parameters() if p.requires_grad]
    opt = optim.AdamW(trainable, lr=lr, weight_decay=wd)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    for ep in range(1, epochs + 1):
        m.train()
        # m.train() flips every submodule back to train mode, so the frozen
        # blocks must be returned to eval mode each epoch (no dropout etc).
        m.blocks.eval()
        for x, y in train_loader:
            x = x.to(dev)
            y = y.to(dev)
            loss = F.cross_entropy(m(x), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        sch.step()
        if ep % 5 == 0 or ep == 1 or ep == epochs:
            acc = evaluate(m, test_loader, dev)
            print(f" BP-frozen ep {ep}: test_acc={acc:.4f}", flush=True)
    return m


def train_dfa_frozen(train_loader, test_loader, dev, epochs=30, seed=42, lr=1e-3, wd=0.05):
    """DFA-style baseline with the 4 transformer blocks frozen at random init.

    Trainable: patch_embed, cls_token, pos_embed, out_ln, out_head.
    The head (out_ln + out_head) is trained with true cross-entropy on the
    detached cls token of the last hidden state.  The embedding parameters
    are trained with a local linear loss against the RMS-normalised random
    feedback ``e_T @ B0.T``, broadcast over every token of ``m.embed(x)``.

    Returns:
        The trained model.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    m = ViTMini(num_blocks=4, d_model=128, n_heads=4).to(dev)
    freeze_blocks(m)
    n_trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in m.parameters())
    print(f"DFA-frozen-blocks: {n_trainable}/{n_total} params trainable", flush=True)

    # Must match the d_model / class count the model was built with above.
    d_model, C = 128, 10
    B0 = torch.randn(d_model, C, device=dev) / np.sqrt(C)
    embed_opt = optim.AdamW(
        list(m.patch_embed.parameters()) + [m.cls_token, m.pos_embed],
        lr=lr, weight_decay=wd
    )
    head_opt = optim.AdamW(
        list(m.out_head.parameters()) + list(m.out_ln.parameters()),
        lr=lr, weight_decay=wd
    )
    sch1 = optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs)
    sch2 = optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)
    for ep in range(1, epochs + 1):
        m.train()
        m.blocks.eval()  # undo m.train() for the frozen blocks
        for x, y in train_loader:
            x = x.to(dev)
            y = y.to(dev)
            with torch.no_grad():
                logits, hi = m(x, return_hidden=True)
                # Output error of softmax cross-entropy: p - onehot(y).
                e_T = logits.softmax(-1)
                e_T[torch.arange(x.size(0)), y] -= 1
            hL_det = hi[-1].detach()

            # Head update via true CE on cls token
            h_cls = m.out_ln(hL_det[:, 0])
            head_opt.zero_grad()
            F.cross_entropy(m.out_head(h_cls), y).backward()
            head_opt.step()

            # Embed update via DFA feedback
            a0 = (e_T @ B0.T).detach()
            rms = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
            h0 = m.embed(x)
            a0_b = a0.unsqueeze(1).expand_as(h0)
            embed_loss = (h0 * (a0_b / rms.unsqueeze(1))).sum(-1).mean()
            embed_opt.zero_grad()
            embed_loss.backward()
            embed_opt.step()
        sch1.step()
        sch2.step()
        if ep % 5 == 0 or ep == 1 or ep == epochs:
            acc = evaluate(m, test_loader, dev)
            print(f" DFA-frozen ep {ep}: test_acc={acc:.4f}", flush=True)
    return m


def main():
    """CLI entry point: run the BP and DFA frozen-blocks baselines and summarise.

    Command-line arguments: ``--seed`` (default 42) and ``--epochs``
    (default 30).
    """
    import argparse
    p = argparse.ArgumentParser()
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--epochs', type=int, default=30)
    args = p.parse_args()

    # Unchanged on CUDA machines; CPU fallback avoids a crash elsewhere.
    dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {dev}, seed={args.seed}, epochs={args.epochs}", flush=True)
    train_loader, test_loader = get_loaders(batch_size=128)

    print(f"\n=== BP frozen-blocks baseline (4 random-init transformer blocks, frozen), seed={args.seed} ===", flush=True)
    mb = train_bp_frozen(train_loader, test_loader, dev, epochs=args.epochs, seed=args.seed)
    bp_acc = evaluate(mb, test_loader, dev)
    print(f"FINAL BP-frozen-blocks acc: {bp_acc:.4f}", flush=True)

    print(f"\n=== DFA frozen-blocks baseline, seed={args.seed} ===", flush=True)
    md = train_dfa_frozen(train_loader, test_loader, dev, epochs=args.epochs, seed=args.seed)
    dfa_acc = evaluate(md, test_loader, dev)
    print(f"FINAL DFA-frozen-blocks acc: {dfa_acc:.4f}", flush=True)

    print("\n=== Summary ===")
    print(f"BP-frozen-blocks: {bp_acc:.4f} (chance=0.10)")
    print(f"DFA-frozen-blocks: {dfa_acc:.4f}")
    # Reference numbers below are constants recorded from earlier runs, not
    # values computed by this script.
    print("Compare to ViT-Mini 4-block trainable (3-seed avg): BP=0.792, DFA=0.237")
    print("Compare to ViT-Mini 0-block (shallow baseline): BP=0.10, DFA=0.10")
    print()
    print("Interpretation:")
    print(" If DFA-frozen-blocks ≈ 0.237: blocks are passengers, DFA is just learning patch_embed+head")
    print(" If DFA-frozen-blocks << 0.237: trainable blocks ARE doing learned work")
    print(" If DFA-frozen-blocks ~ 0.10: untrained blocks add no useful mixing (less informative)")


if __name__ == '__main__':
    main()
def get_loaders(batch_size=128):
    """Build CIFAR-10 train/test DataLoaders (downloads to ``./data`` if needed).

    The training set uses random crop (padding 4) and horizontal flip; both
    splits are normalised with the same per-channel mean/std.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
    tv_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
    te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
    return (
        DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
        DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2),
    )


def evaluate(model, loader, dev):
    """Return top-1 accuracy of ``model`` over every batch in ``loader``.

    Side effect: the model is left in eval mode; callers that keep training
    must switch it back themselves.

    Raises:
        ValueError: if ``loader`` yields no samples (accuracy is undefined;
            previously this surfaced as a bare ZeroDivisionError).
    """
    model.eval()
    n_seen = 0
    n_correct = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(dev), y.to(dev)
            preds = model(x).argmax(-1)
            n_correct += (preds == y).sum().item()
            n_seen += x.size(0)
    if n_seen == 0:
        raise ValueError("evaluate() received an empty loader; accuracy is undefined")
    return n_correct / n_seen


def train_bp_shallow(train_loader, test_loader, dev, epochs=30, seed=42, lr=1e-3, wd=0.05):
    """Backprop baseline on a 0-block ViT-Mini.

    Trains every parameter with AdamW and a cosine schedule; test accuracy
    is printed at epoch 1, every 5th epoch and the final epoch.

    Returns:
        The trained model.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    m = ViTMini(num_blocks=0, d_model=128, n_heads=4).to(dev)
    print(f"BP-shallow: n_params={sum(p.numel() for p in m.parameters())}", flush=True)
    opt = optim.AdamW(m.parameters(), lr=lr, weight_decay=wd)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    for ep in range(1, epochs + 1):
        m.train()  # evaluate() leaves the model in eval mode
        for x, y in train_loader:
            x = x.to(dev)
            y = y.to(dev)
            loss = F.cross_entropy(m(x), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        sch.step()
        if ep % 5 == 0 or ep == 1 or ep == epochs:
            acc = evaluate(m, test_loader, dev)
            print(f" BP-shallow ep {ep}: test_acc={acc:.4f}", flush=True)
    return m


def train_dfa_shallow(train_loader, test_loader, dev, epochs=30, seed=42, lr=1e-3, wd=0.05):
    """0-block ViT trained DFA-style.

    The head (out_ln + out_head) is trained with true cross-entropy on the
    detached cls token; the embedding (patch_embed + cls + pos) is trained
    with a local linear loss against the RMS-normalised random feedback
    ``e_T @ B0.T``, broadcast over every token of ``m.embed(x)``.

    Returns:
        The trained model.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    m = ViTMini(num_blocks=0, d_model=128, n_heads=4).to(dev)
    print(f"DFA-shallow: n_params={sum(p.numel() for p in m.parameters())}", flush=True)

    # Must match the d_model / class count the model was built with above.
    d_model, C = 128, 10
    B0 = torch.randn(d_model, C, device=dev) / np.sqrt(C)
    embed_opt = optim.AdamW(
        list(m.patch_embed.parameters()) + [m.cls_token, m.pos_embed],
        lr=lr, weight_decay=wd
    )
    head_opt = optim.AdamW(
        list(m.out_head.parameters()) + list(m.out_ln.parameters()),
        lr=lr, weight_decay=wd
    )
    sch1 = optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs)
    sch2 = optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)
    for ep in range(1, epochs + 1):
        m.train()
        for x, y in train_loader:
            x = x.to(dev)
            y = y.to(dev)
            with torch.no_grad():
                logits, hi = m(x, return_hidden=True)
                # Output error of softmax cross-entropy: p - onehot(y).
                e_T = logits.softmax(-1)
                e_T[torch.arange(x.size(0)), y] -= 1
            hL_det = hi[-1].detach()

            # Head update via true CE on cls token
            h_cls = m.out_ln(hL_det[:, 0])
            head_opt.zero_grad()
            F.cross_entropy(m.out_head(h_cls), y).backward()
            head_opt.step()

            # Embed update via DFA-style local loss
            a0 = (e_T @ B0.T).detach()
            rms = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
            h0 = m.embed(x)  # (B, 65, d_model)
            a0_b = a0.unsqueeze(1).expand_as(h0)
            embed_loss = (h0 * (a0_b / rms.unsqueeze(1))).sum(-1).mean()
            embed_opt.zero_grad()
            embed_loss.backward()
            embed_opt.step()
        sch1.step()
        sch2.step()
        if ep % 5 == 0 or ep == 1 or ep == epochs:
            acc = evaluate(m, test_loader, dev)
            print(f" DFA-shallow ep {ep}: test_acc={acc:.4f}", flush=True)
    return m


def main():
    """CLI entry point: run the BP and DFA shallow (0-block) baselines.

    Command-line arguments (both optional, defaults reproduce the previous
    hard-coded run): ``--seed`` (default 42) and ``--epochs`` (default 30).
    """
    # Local import: argparse is only needed by this entry point.
    import argparse
    p = argparse.ArgumentParser()
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--epochs', type=int, default=30)
    args = p.parse_args()

    # Unchanged on CUDA machines; CPU fallback avoids a crash elsewhere.
    dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {dev}", flush=True)
    train_loader, test_loader = get_loaders(batch_size=128)

    print("\n=== BP shallow baseline (ViT-Mini num_blocks=0) ===", flush=True)
    mb = train_bp_shallow(train_loader, test_loader, dev, epochs=args.epochs, seed=args.seed)
    bp_acc = evaluate(mb, test_loader, dev)
    print(f"FINAL BP-shallow acc: {bp_acc:.4f}", flush=True)

    print("\n=== DFA shallow baseline (ViT-Mini num_blocks=0) ===", flush=True)
    md = train_dfa_shallow(train_loader, test_loader, dev, epochs=args.epochs, seed=args.seed)
    dfa_acc = evaluate(md, test_loader, dev)
    print(f"FINAL DFA-shallow acc: {dfa_acc:.4f}", flush=True)

    print("\n=== Summary ===")
    print(f"BP-shallow: {bp_acc:.4f} (chance=0.10)")
    print(f"DFA-shallow: {dfa_acc:.4f}")
    # Reference numbers are constants recorded from earlier runs.
    print("Compare to ViT-Mini 4-block (3-seed avg): BP=0.792, DFA=0.237")


if __name__ == '__main__':
    main()
