summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-06-14 04:06:32 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-06-14 04:06:32 -0500
commitaa73718eb6427d7da3b9cb416275802d90c4b2ed (patch)
treeb68b0a664fb650744ef934a1c22abd740a7b62a6 /experiments
parent827c658fa9a750f3c6ebdb87703762f10f69f6ff (diff)
Add new experiment scripts, figures, and paper assets; untrack pyc/build artifactsHEADmaster
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/__pycache__/__init__.cpython-313.pycbin137 -> 0 bytes
-rw-r--r--experiments/__pycache__/toy_lq.cpython-313.pycbin19620 -> 0 bytes
-rw-r--r--experiments/analyze_snapshot_evolution.py60
-rw-r--r--experiments/dfa_penalty_freshB.py183
-rw-r--r--experiments/dfa_penalty_trajectory.py135
-rw-r--r--experiments/figure_snapshot_evolution.py178
-rw-r--r--experiments/frozen_baselines_crossarch.py191
-rw-r--r--experiments/resnet_frozen_blocks_baseline.py278
-rw-r--r--experiments/resnet_protocol_validation.py343
-rw-r--r--experiments/snapshot_compare_outln.py93
-rw-r--r--experiments/snapshot_evolution_no_outln.py249
-rw-r--r--experiments/snapshot_evolution_residual_explosion.py78
-rw-r--r--experiments/snapshot_evolution_vit.py244
-rw-r--r--experiments/snapshot_fa_crossarch.py243
-rw-r--r--experiments/snapshot_fa_only.py38
-rw-r--r--experiments/snapshot_fa_studentnet.py94
-rw-r--r--experiments/snapshot_synth_residual_explosion.py195
-rw-r--r--experiments/vit_frozen_blocks_baseline.py177
-rw-r--r--experiments/vit_shallow_baseline.py147
19 files changed, 2926 insertions, 0 deletions
diff --git a/experiments/__pycache__/__init__.cpython-313.pyc b/experiments/__pycache__/__init__.cpython-313.pyc
deleted file mode 100644
index 5966841..0000000
--- a/experiments/__pycache__/__init__.cpython-313.pyc
+++ /dev/null
Binary files differ
diff --git a/experiments/__pycache__/toy_lq.cpython-313.pyc b/experiments/__pycache__/toy_lq.cpython-313.pyc
deleted file mode 100644
index d8710a8..0000000
--- a/experiments/__pycache__/toy_lq.cpython-313.pyc
+++ /dev/null
Binary files differ
diff --git a/experiments/analyze_snapshot_evolution.py b/experiments/analyze_snapshot_evolution.py
new file mode 100644
index 0000000..8b9f8af
--- /dev/null
+++ b/experiments/analyze_snapshot_evolution.py
@@ -0,0 +1,60 @@
+"""
+Read snapshot evolution JSONs (BP vs DFA over training epochs), summarize and
+print comparison tables. Used for the P4 paper figure.
+
+Usage:
+ python experiments/analyze_snapshot_evolution.py <json_path>
+"""
+import sys, json
+import numpy as np
+
+
+def summarize(log, name):
+ eps = [d['epoch'] for d in log]
+ h_L = [d['hidden_norms'][-1] for d in log]
+ g_l2 = [d['bp_grad_per_sample_l2_med'][2] if 'bp_grad_per_sample_l2_med' in d
+ else d['bp_grad_norms_per_sample_med'][2] for d in log]
+ acc = [d['acc_eval'] for d in log]
+ print(f"\n{name} ({len(log)} epochs):")
+ print(f" ||h_L||_2 median: ep0={h_L[0]:.3e} -> ep{eps[len(eps)//2]}={h_L[len(eps)//2]:.3e} -> ep{eps[-1]}={h_L[-1]:.3e}")
+ print(f" ||BP grad at h_2||_2 median: ep0={g_l2[0]:.3e} -> ep{eps[len(eps)//2]}={g_l2[len(eps)//2]:.3e} -> ep{eps[-1]}={g_l2[-1]:.3e}")
+ print(f" acc: ep0={acc[0]:.4f} -> ep{eps[-1]}={acc[-1]:.4f}")
+ print(f" ||h_L|| growth (final/initial): {h_L[-1]/max(h_L[0], 1e-12):.3e}")
+ print(f" ||BP_g|| change (final/initial): {g_l2[-1]/max(g_l2[0], 1e-30):.3e}")
+
+
+def main():
+ path = sys.argv[1] if len(sys.argv) > 1 else 'results/snapshot_evolution_v2/snapshot_evolution_s42.json'
+ with open(path) as f:
+ d = json.load(f)
+ print(f"Loaded {path}")
+ print(f"config: {d.get('config', {})}")
+ print(f"depth={d.get('depth')}, d_hidden={d.get('d_hidden')}")
+ if 'bp_log' in d:
+ summarize(d['bp_log'], 'BP')
+ if 'dfa_log' in d:
+ summarize(d['dfa_log'], 'DFA')
+
+ # Print compact per-epoch comparison if both available
+ if 'bp_log' in d and 'dfa_log' in d:
+ bp = d['bp_log']
+ dfa = d['dfa_log']
+ eps = sorted(set([x['epoch'] for x in bp]) & set([x['epoch'] for x in dfa]))
+ sample_eps = [eps[i] for i in [0, len(eps)//4, len(eps)//2, 3*len(eps)//4, -1]]
+ print(f"\nPer-epoch sample (BP vs DFA):")
+ print(f"{'epoch':>6s} {'BP_||h_L||':>12s} {'DFA_||h_L||':>12s} {'BP_||g_2||':>12s} {'DFA_||g_2||':>12s} {'BP_acc':>8s} {'DFA_acc':>8s}")
+ bp_d = {x['epoch']: x for x in bp}
+ dfa_d = {x['epoch']: x for x in dfa}
+ for e in sample_eps:
+ bdat = bp_d[e]
+ ddat = dfa_d[e]
+ bh = bdat['hidden_norms'][-1]
+ dh = ddat['hidden_norms'][-1]
+ bg_key = 'bp_grad_per_sample_l2_med' if 'bp_grad_per_sample_l2_med' in bdat else 'bp_grad_norms_per_sample_med'
+ bg = bdat[bg_key][2]
+ dg = ddat[bg_key][2]
+ print(f"{e:>6d} {bh:>12.3e} {dh:>12.3e} {bg:>12.3e} {dg:>12.3e} {bdat['acc_eval']:>8.4f} {ddat['acc_eval']:>8.4f}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/dfa_penalty_freshB.py b/experiments/dfa_penalty_freshB.py
new file mode 100644
index 0000000..82b192d
--- /dev/null
+++ b/experiments/dfa_penalty_freshB.py
@@ -0,0 +1,183 @@
+"""
+DFA canonical λ=1e-2 training + checkpoint save + fresh-B null calibration.
+Runs after the main penalty sweep to produce the null calibration on the canonical checkpoint.
+"""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision, torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+from metrics.credit_metrics import cosine_similarity_batch
+
+
+def get_data(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2))
+
+
+def train_dfa_canonical(model, train_loader, device, epochs, lr, wd, penalty_lam):
+ """Canonical DFA from cifar_resmlp.py: no grad clipping, mean reduction."""
+ d = model.d_hidden
+ L = model.num_blocks
+ C = 10
+ Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+ block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=lr, weight_decay=wd)
+ all_sch = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)])
+
+ for epoch in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1
+ hL_det = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL_det))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad(); loss_out.backward(); head_opt.step()
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a_dfa = (e_T @ Bs[l].T).detach()
+ rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * (a_dfa / rms)).sum(dim=-1).mean()
+ if penalty_lam > 0:
+ local_loss = local_loss + penalty_lam * (f_l ** 2).sum(dim=-1).mean()
+ block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step()
+ a_0 = (e_T @ Bs[0].T).detach()
+ rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+ for s in all_sch: s.step()
+ if epoch % 10 == 0 or epoch == epochs:
+ print(f" [DFA pen] ep {epoch}", flush=True)
+ return Bs
+
+
+def compute_deep_cosine(model, Bs, x_eval, y_eval, device):
+ """Compute per-layer DFA cosine on eval buffer."""
+ model.eval()
+ L = model.num_blocks
+ h0 = model.embed(x_eval.detach())
+ hs = [h0.clone().requires_grad_(True)]
+ for b in model.blocks:
+ hs.append(hs[-1] + b(hs[-1]))
+ logits = model.out_head(model.out_ln(hs[-1]))
+ loss = F.cross_entropy(logits, y_eval)
+ grads = torch.autograd.grad(loss, hs)
+ with torch.no_grad():
+ e_T = logits.softmax(-1)
+ e_T[torch.arange(x_eval.size(0)), y_eval] -= 1
+ cos_per_layer = []
+ for l in range(L):
+ a_dfa = (e_T @ Bs[l].T).detach()
+ cos_per_layer.append(cosine_similarity_batch(a_dfa, grads[l].detach()))
+ acc = (logits.argmax(-1) == y_eval).float().mean().item()
+ g_norms = [g.norm(dim=-1).median().item() for g in grads]
+ h_norms = [h.detach().norm(dim=-1).median().item() for h in hs]
+ return cos_per_layer, acc, g_norms, h_norms
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--seed', type=int, default=42)
+ p.add_argument('--output_dir', type=str, default='results/dfa_canonical_freshB')
+ p.add_argument('--n_fresh', type=int, default=20)
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ device = torch.device('cuda:0')
+ train_loader, test_loader = get_data(128)
+
+ # Fixed eval buffer
+ xs, ys = [], []
+ for x, y in test_loader:
+ xs.append(x.view(x.size(0), -1)); ys.append(y)
+ if sum(xb.size(0) for xb in xs) >= 128:
+ break
+ x_eval = torch.cat(xs)[:128].to(device)
+ y_eval = torch.cat(ys)[:128].to(device)
+
+ L, d, C = 4, 256, 10
+
+ # Train DFA with λ=1e-2
+ print(f"Training DFA canonical λ=0.01, seed={args.seed}", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ model = ResidualMLP(3072, d, C, L).to(device)
+ training_Bs = train_dfa_canonical(model, train_loader, device, 30, 1e-3, 0.01, 0.01)
+
+ # Save checkpoint
+ ckpt_path = os.path.join(args.output_dir, f'dfa_canonical_lam0.01_s{args.seed}.pt')
+ torch.save({'state_dict': model.state_dict(),
+ 'Bs': [B.cpu() for B in training_Bs],
+ 'seed': args.seed}, ckpt_path)
+ print(f"Saved checkpoint: {ckpt_path}", flush=True)
+
+ # Compute cosine with training Bs
+ cos_training, acc, g_norms, h_norms = compute_deep_cosine(model, training_Bs, x_eval, y_eval, device)
+ deep_cos_training = float(np.mean(cos_training[1:])) # exclude layer 0
+ print(f"Training-Bs: acc={acc:.4f}, deep cos={deep_cos_training:+.4f}")
+ print(f" per-layer cos: {[f'{c:+.4f}' for c in cos_training]}")
+ print(f" ||g_l||: {[f'{g:.2e}' for g in g_norms]}")
+ print(f" ||h_l||: {[f'{h:.2e}' for h in h_norms]}")
+
+ # Fresh-B null calibration
+ print(f"\nFresh-B null calibration ({args.n_fresh} draws)...", flush=True)
+ fresh_deep_cos = []
+ fresh_per_layer = []
+ for i in range(args.n_fresh):
+ fresh_Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+ cos_fresh, _, _, _ = compute_deep_cosine(model, fresh_Bs, x_eval, y_eval, device)
+ deep_fresh = float(np.mean(cos_fresh[1:]))
+ fresh_deep_cos.append(deep_fresh)
+ fresh_per_layer.append(cos_fresh)
+ fresh_mean = np.mean(fresh_deep_cos)
+ fresh_std_ddof1 = np.std(fresh_deep_cos, ddof=1)
+ print(f"Fresh-Bs deep cos: {fresh_mean:+.4f} ± {fresh_std_ddof1:.4f} (ddof=1)")
+
+ # Save results
+ out = {
+ 'description': f'Canonical DFA λ=0.01 s={args.seed} + fresh-B null (N={args.n_fresh})',
+ 'training_Bs_deep_cos': deep_cos_training,
+ 'training_Bs_per_layer_cos': cos_training,
+ 'training_Bs_acc': acc,
+ 'training_Bs_g_norms': g_norms,
+ 'training_Bs_h_norms': h_norms,
+ 'fresh_Bs_n_draws': args.n_fresh,
+ 'fresh_Bs_deep_cos_per_draw': fresh_deep_cos,
+ 'fresh_Bs_deep_mean': fresh_mean,
+ 'fresh_Bs_deep_std_ddof1': fresh_std_ddof1,
+ 'fresh_Bs_per_layer_mean': [float(np.mean([fl[l] for fl in fresh_per_layer])) for l in range(L)],
+ }
+ out_path = os.path.join(args.output_dir, f'freshB_null_canonical_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(out, f, indent=2)
+ print(f"Saved: {out_path}", flush=True)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/dfa_penalty_trajectory.py b/experiments/dfa_penalty_trajectory.py
new file mode 100644
index 0000000..c46ce0b
--- /dev/null
+++ b/experiments/dfa_penalty_trajectory.py
@@ -0,0 +1,135 @@
+"""
+Canonical DFA penalty trajectory: per-epoch ||h_L|| and ||g_L|| for λ ∈ {0, 1e-4, 1e-2}.
+3 seeds × 3 λ × 30 epochs. Uses canonical cifar_resmlp.py DFA implementation (no clipping, mean reduction).
+"""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision, torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+
+
+def get_data(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2))
+
+
+def diagnose_quick(model, x_eval, y_eval):
+ model.eval()
+ x_flat = x_eval.view(x_eval.size(0), -1)
+ with torch.no_grad():
+ logits, hiddens = model(x_flat, return_hidden=True)
+ h_L = hiddens[-1].norm(dim=-1).median().item()
+ # BP grad at h_L
+ h0 = model.embed(x_flat.detach())
+ hs = [h0.clone().requires_grad_(True)]
+ for b in model.blocks:
+ hs.append(hs[-1] + b(hs[-1]))
+ logits2 = model.out_head(model.out_ln(hs[-1]))
+ loss = F.cross_entropy(logits2, y_eval)
+ grads = torch.autograd.grad(loss, hs)
+ g_L = grads[-1].norm(dim=-1).median().item()
+ acc = (logits.argmax(-1) == y_eval).float().mean().item()
+ model.train()
+ return h_L, g_L, acc
+
+
+def train_dfa_trajectory(seed, train_loader, x_eval, y_eval, device, epochs, lam):
+ L, d, C = 4, 256, 10
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ model = ResidualMLP(3072, d, C, L).to(device)
+ Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+ block_opts = [optim.AdamW(block.parameters(), lr=1e-3, weight_decay=0.01) for block in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=1e-3, weight_decay=0.01)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=1e-3, weight_decay=0.01)
+ all_sch = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)])
+
+ log = []
+ h_L, g_L, acc = diagnose_quick(model, x_eval, y_eval)
+ log.append({'epoch': 0, 'h_L': h_L, 'g_L': g_L, 'acc': acc})
+
+ for epoch in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1
+ hL_det = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL_det))
+ head_opt.zero_grad(); F.cross_entropy(logits_out, y).backward(); head_opt.step()
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a_dfa = (e_T @ Bs[l].T).detach()
+ rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * (a_dfa / rms)).sum(dim=-1).mean()
+ if lam > 0:
+ local_loss = local_loss + lam * (f_l ** 2).sum(dim=-1).mean()
+ block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step()
+ a_0 = (e_T @ Bs[0].T).detach()
+ rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+ for s in all_sch: s.step()
+ h_L, g_L, acc = diagnose_quick(model, x_eval, y_eval)
+ log.append({'epoch': epoch, 'h_L': h_L, 'g_L': g_L, 'acc': acc})
+ if epoch % 10 == 0 or epoch == epochs:
+ print(f" [lam={lam}] s={seed} ep {epoch}: ||h_L||={h_L:.3e} ||g_L||={g_L:.3e} acc={acc:.4f}", flush=True)
+ return log
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--output', type=str, default='results/dfa_canonical_penalty_trajectory.json')
+ args = p.parse_args()
+
+ device = torch.device('cuda:0')
+ train_loader, test_loader = get_data(128)
+ # Fixed 128-sample eval buffer (consistent with cifar_resmlp.py compute_diagnostics)
+ xs, ys = [], []
+ for x, y in test_loader:
+ xs.append(x); ys.append(y)
+ if sum(xb.size(0) for xb in xs) >= 128:
+ break
+ x_eval = torch.cat(xs)[:128].to(device)
+ y_eval = torch.cat(ys)[:128].to(device)
+
+ results = {}
+ for lam in [0.0, 1e-4, 1e-2]:
+ lam_key = f'lam_{lam}'
+ results[lam_key] = {}
+ for seed in [42, 123, 456]:
+ print(f"\n=== λ={lam}, seed={seed} ===", flush=True)
+ log = train_dfa_trajectory(seed, train_loader, x_eval, y_eval, device, 30, lam)
+ results[lam_key][str(seed)] = log
+
+ with open(args.output, 'w') as f:
+ json.dump(results, f, indent=2)
+ print(f"\nSaved: {args.output}", flush=True)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/figure_snapshot_evolution.py b/experiments/figure_snapshot_evolution.py
new file mode 100644
index 0000000..b06f417
--- /dev/null
+++ b/experiments/figure_snapshot_evolution.py
@@ -0,0 +1,178 @@
+"""
+Generate the snapshot-evolution figure(s) for the paper from existing JSONs.
+
+Produces:
+ - figure_snapshot_resmlp.pdf : ResMLP with vs without out_ln, ||h_L|| and ||g||
+ over epochs for BP and DFA
+ - figure_snapshot_vit.pdf : ViT-Mini ||h_L|| and ||g|| over epochs for BP/DFA
+
+Usage:
+ python experiments/figure_snapshot_evolution.py
+"""
+import os, sys, json
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+
+
+def load_log(path, log_key):
+ if not os.path.exists(path):
+ return None
+ with open(path) as f:
+ return json.load(f).get(log_key)
+
+
+def trajectory(log, metric):
+ """Extract a per-epoch trajectory for the given metric."""
+ eps = [r['epoch'] for r in log]
+ if metric == 'h_L':
+ # last hidden norm — handles both ResMLP (hidden_norms) and ViT (hidden_norms_cls)
+ values = []
+ for r in log:
+ if 'hidden_norms_cls' in r:
+ values.append(r['hidden_norms_cls'][-1])
+ else:
+ values.append(r['hidden_norms'][-1])
+ elif metric == 'g_2':
+ values = []
+ for r in log:
+ key = 'bp_grad_per_sample_l2_med' if 'bp_grad_per_sample_l2_med' in r else 'bp_grad_norms_per_sample_med'
+ values.append(r[key][2])
+ elif metric == 'acc':
+ values = [r['acc_eval'] for r in log]
+ elif metric == 'gamma_dfa':
+ values = [r.get('gamma_dfa', float('nan')) for r in log]
+ else:
+ return None, None
+ return np.array(eps), np.array(values)
+
+
+def make_resmlp_figure(out_path):
+ fig, axes = plt.subplots(2, 2, figsize=(10, 7), sharex=True)
+
+ runs = {
+ 'with_out_ln_s42': 'results/snapshot_evolution_v2/snapshot_evolution_s42.json',
+ 'no_out_ln_s42': 'results/snapshot_no_outln_v1/snapshot_noLN_s42.json',
+ 'no_out_ln_s123': 'results/snapshot_no_outln_v1/snapshot_noLN_s123.json',
+ 'no_out_ln_s456': 'results/snapshot_no_outln_v1/snapshot_noLN_s456.json',
+ }
+ runs_loaded = {k: (load_log(v, 'bp_log'), load_log(v, 'dfa_log')) for k, v in runs.items()}
+
+ # Top row: with out_ln
+ bp, dfa = runs_loaded['with_out_ln_s42']
+ ax = axes[0, 0]
+ e, v = trajectory(bp, 'h_L'); ax.plot(e, v, 'b-', label='BP', lw=2)
+ e, v = trajectory(dfa, 'h_L'); ax.plot(e, v, 'r-', label='DFA', lw=2)
+ ax.set_yscale('log'); ax.set_ylabel(r'$\|h_L\|_2$ (median)')
+ ax.set_title('ResMLP with terminal LayerNorm (s42)')
+ ax.legend(); ax.grid(True, alpha=0.3)
+
+ ax = axes[0, 1]
+ e, v = trajectory(bp, 'g_2'); ax.plot(e, v, 'b-', label='BP', lw=2)
+ e, v = trajectory(dfa, 'g_2'); ax.plot(e, v, 'r-', label='DFA', lw=2)
+ ax.set_yscale('log'); ax.set_ylabel(r'$\|\nabla_{h_2} L\|_2$ (BP grad, median)')
+ ax.set_title('ResMLP with terminal LayerNorm (s42)')
+ ax.legend(); ax.grid(True, alpha=0.3)
+
+ # Bottom row: no out_ln, mean ± std across 3 seeds
+ no_ln_bp_h = []; no_ln_bp_g = []; no_ln_dfa_h = []; no_ln_dfa_g = []
+ for k in ['no_out_ln_s42', 'no_out_ln_s123', 'no_out_ln_s456']:
+ bp, dfa = runs_loaded[k]
+ if bp is None or dfa is None: continue
+ e_bp, h_bp = trajectory(bp, 'h_L'); _, g_bp = trajectory(bp, 'g_2')
+ e_dfa, h_dfa = trajectory(dfa, 'h_L'); _, g_dfa = trajectory(dfa, 'g_2')
+ no_ln_bp_h.append(h_bp); no_ln_bp_g.append(g_bp)
+ no_ln_dfa_h.append(h_dfa); no_ln_dfa_g.append(g_dfa)
+
+ if no_ln_bp_h:
+ eps = e_bp
+ bp_h_arr = np.array(no_ln_bp_h)
+ bp_g_arr = np.array(no_ln_bp_g)
+ dfa_h_arr = np.array(no_ln_dfa_h)
+ dfa_g_arr = np.array(no_ln_dfa_g)
+
+ ax = axes[1, 0]
+ ax.plot(eps, np.mean(bp_h_arr, 0), 'b-', label='BP', lw=2)
+ ax.fill_between(eps, np.mean(bp_h_arr, 0)-np.std(bp_h_arr, 0), np.mean(bp_h_arr, 0)+np.std(bp_h_arr, 0), color='b', alpha=0.2)
+ ax.plot(eps, np.mean(dfa_h_arr, 0), 'r-', label='DFA', lw=2)
+ ax.fill_between(eps, np.mean(dfa_h_arr, 0)-np.std(dfa_h_arr, 0), np.mean(dfa_h_arr, 0)+np.std(dfa_h_arr, 0), color='r', alpha=0.2)
+ ax.set_yscale('log'); ax.set_xlabel('epoch'); ax.set_ylabel(r'$\|h_L\|_2$ (median)')
+ ax.set_title(f'ResMLP WITHOUT terminal LayerNorm (mean ± std, n={len(no_ln_bp_h)})')
+ ax.legend(); ax.grid(True, alpha=0.3)
+
+ ax = axes[1, 1]
+ ax.plot(eps, np.mean(bp_g_arr, 0), 'b-', label='BP', lw=2)
+ ax.fill_between(eps, np.mean(bp_g_arr, 0)-np.std(bp_g_arr, 0), np.mean(bp_g_arr, 0)+np.std(bp_g_arr, 0), color='b', alpha=0.2)
+ ax.plot(eps, np.mean(dfa_g_arr, 0), 'r-', label='DFA', lw=2)
+ ax.fill_between(eps, np.mean(dfa_g_arr, 0)-np.std(dfa_g_arr, 0), np.mean(dfa_g_arr, 0)+np.std(dfa_g_arr, 0), color='r', alpha=0.2)
+ ax.set_yscale('log'); ax.set_xlabel('epoch'); ax.set_ylabel(r'$\|\nabla_{h_2} L\|_2$ (BP grad, median)')
+ ax.set_title(f'ResMLP WITHOUT terminal LayerNorm (mean ± std, n={len(no_ln_bp_h)})')
+ ax.legend(); ax.grid(True, alpha=0.3)
+
+ plt.suptitle('Snapshot evolution: residual stream + BP grad over training\n(top: with terminal LN — DFA explodes; bottom: no terminal LN — DFA still grows but BP grad does NOT collapse)', y=1.02)
+ plt.tight_layout()
+ plt.savefig(out_path, bbox_inches='tight', dpi=150)
+ print(f"Saved {out_path}")
+ plt.close()
+
+
+def make_vit_figure(out_path):
+ fig, axes = plt.subplots(1, 2, figsize=(11, 4))
+
+ runs = sorted([
+ f for f in os.listdir('results/snapshot_vit_v1')
+ if f.startswith('snapshot_vit_s') and f.endswith('.json')
+ ])
+ if not runs:
+ print("No ViT snapshot JSONs found")
+ return
+
+ bp_h_list = []; bp_g_list = []; dfa_h_list = []; dfa_g_list = []
+ eps = None
+ for r in runs:
+ path = f'results/snapshot_vit_v1/{r}'
+ bp = load_log(path, 'bp_log')
+ dfa = load_log(path, 'dfa_log')
+ if bp is None or dfa is None: continue
+ e_bp, h_bp = trajectory(bp, 'h_L'); _, g_bp = trajectory(bp, 'g_2')
+ e_dfa, h_dfa = trajectory(dfa, 'h_L'); _, g_dfa = trajectory(dfa, 'g_2')
+ bp_h_list.append(h_bp); bp_g_list.append(g_bp)
+ dfa_h_list.append(h_dfa); dfa_g_list.append(g_dfa)
+ eps = e_bp
+
+ bp_h_arr = np.array(bp_h_list); bp_g_arr = np.array(bp_g_list)
+ dfa_h_arr = np.array(dfa_h_list); dfa_g_arr = np.array(dfa_g_list)
+
+ ax = axes[0]
+ ax.plot(eps, np.mean(bp_h_arr, 0), 'b-', label='BP', lw=2)
+ if len(bp_h_list) > 1:
+ ax.fill_between(eps, np.mean(bp_h_arr, 0)-np.std(bp_h_arr, 0), np.mean(bp_h_arr, 0)+np.std(bp_h_arr, 0), color='b', alpha=0.2)
+ ax.plot(eps, np.mean(dfa_h_arr, 0), 'r-', label='DFA', lw=2)
+ if len(dfa_h_list) > 1:
+ ax.fill_between(eps, np.mean(dfa_h_arr, 0)-np.std(dfa_h_arr, 0), np.mean(dfa_h_arr, 0)+np.std(dfa_h_arr, 0), color='r', alpha=0.2)
+ ax.set_yscale('log'); ax.set_xlabel('epoch'); ax.set_ylabel(r'$\|h_L^{cls}\|_2$ (median)')
+ ax.set_title(f'ViT-Mini, terminal LayerNorm (n={len(bp_h_list)})')
+ ax.legend(); ax.grid(True, alpha=0.3)
+
+ ax = axes[1]
+ ax.plot(eps, np.mean(bp_g_arr, 0), 'b-', label='BP', lw=2)
+ if len(bp_g_list) > 1:
+ ax.fill_between(eps, np.mean(bp_g_arr, 0)-np.std(bp_g_arr, 0), np.mean(bp_g_arr, 0)+np.std(bp_g_arr, 0), color='b', alpha=0.2)
+ ax.plot(eps, np.mean(dfa_g_arr, 0), 'r-', label='DFA', lw=2)
+ if len(dfa_g_list) > 1:
+ ax.fill_between(eps, np.mean(dfa_g_arr, 0)-np.std(dfa_g_arr, 0), np.mean(dfa_g_arr, 0)+np.std(dfa_g_arr, 0), color='r', alpha=0.2)
+ ax.set_yscale('log'); ax.set_xlabel('epoch'); ax.set_ylabel(r'$\|\nabla_{h_2} L\|_2$ (BP grad, median)')
+ ax.set_title(f'ViT-Mini, terminal LayerNorm (n={len(bp_g_list)})')
+ ax.legend(); ax.grid(True, alpha=0.3)
+
+ plt.tight_layout()
+ plt.savefig(out_path, bbox_inches='tight', dpi=150)
+ print(f"Saved {out_path}")
+ plt.close()
+
+
+if __name__ == '__main__':
+ os.makedirs('results/figures', exist_ok=True)
+ make_resmlp_figure('results/figures/figure_snapshot_resmlp.pdf')
+ make_vit_figure('results/figures/figure_snapshot_vit.pdf')
diff --git a/experiments/frozen_baselines_crossarch.py b/experiments/frozen_baselines_crossarch.py
new file mode 100644
index 0000000..a3dd76c
--- /dev/null
+++ b/experiments/frozen_baselines_crossarch.py
@@ -0,0 +1,191 @@
+"""
+Frozen-blocks baselines for ViT-Mini and StudentNet.
+Trains only embed/head/LN with blocks frozen at random init.
+Also trains shallow (no blocks) variant for comparison.
+"""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader, TensorDataset
+import torchvision, torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.vit_mini import ViTMini
+from experiments.confirmatory_paper_experiments import (
+ StudentNet, TeacherNet, generate_synth_dataset, set_seed
+)
+
+
+def get_cifar10(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2))
+
+
+def evaluate(model, loader, device, is_vit=False):
+ model.eval()
+ c = n = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x = x.to(device); y = y.to(device)
+ if not is_vit:
+ x = x.view(x.size(0), -1) if x.dim() == 4 else x
+ preds = model(x).argmax(-1)
+ c += (preds == y).sum().item()
+ n += x.size(0)
+ return c / n
+
+
+def freeze_blocks(model):
+ for p in model.blocks.parameters():
+ p.requires_grad_(False)
+
+
+# ─── ViT-Mini frozen/shallow ────────────────────────────────────────────
+
+def train_vit_frozen(seed, train_loader, test_loader, device, epochs, lr, wd):
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ model = ViTMini(d_model=128, n_heads=4, num_blocks=4, num_classes=10).to(device)
+ freeze_blocks(model)
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ total = sum(p.numel() for p in model.parameters())
+ print(f" ViT-Mini frozen: {trainable}/{total} trainable params", flush=True)
+ opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.to(device); y = y.to(device)
+ loss = F.cross_entropy(model(x), y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ if ep % 10 == 0 or ep == epochs:
+ acc = evaluate(model, test_loader, device, is_vit=True)
+ print(f" [ViT-frozen] s={seed} ep {ep}: acc={acc:.4f}", flush=True)
+ return evaluate(model, test_loader, device, is_vit=True)
+
+
+def train_vit_shallow(seed, train_loader, test_loader, device, epochs, lr, wd):
+ """ViT with num_blocks=0: just patch_embed + cls + pos + LN + head."""
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ model = ViTMini(d_model=128, n_heads=4, num_blocks=0, num_classes=10).to(device)
+ trainable = sum(p.numel() for p in model.parameters())
+ print(f" ViT-Mini shallow: {trainable} params (no blocks)", flush=True)
+ opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.to(device); y = y.to(device)
+ loss = F.cross_entropy(model(x), y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ if ep % 10 == 0 or ep == epochs:
+ acc = evaluate(model, test_loader, device, is_vit=True)
+ print(f" [ViT-shallow] s={seed} ep {ep}: acc={acc:.4f}", flush=True)
+ return evaluate(model, test_loader, device, is_vit=True)
+
+
+# ─── StudentNet frozen/shallow ──────────────────────────────────────────
+
+def train_student_frozen(seed, train_loader, test_loader, device, epochs, lr, wd, alpha=1.0):
+ set_seed(seed)
+ model = StudentNet(128, 10, 4, alpha).to(device)
+ freeze_blocks(model)
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ total = sum(p.numel() for p in model.parameters())
+ print(f" StudentNet frozen: {trainable}/{total} trainable params", flush=True)
+ opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.to(device); y = y.to(device)
+ loss = F.cross_entropy(model(x), y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ if ep % 10 == 0 or ep == epochs:
+ acc = evaluate(model, test_loader, device)
+ print(f" [Student-frozen] s={seed} ep {ep}: acc={acc:.4f}", flush=True)
+ return evaluate(model, test_loader, device)
+
+
+def train_student_shallow(seed, train_loader, test_loader, device, epochs, lr, wd, alpha=1.0):
+ """StudentNet with num_blocks=0: just out_head (input is d_hidden already)."""
+ set_seed(seed)
+ model = StudentNet(128, 10, 0, alpha).to(device)
+ trainable = sum(p.numel() for p in model.parameters())
+ print(f" StudentNet shallow: {trainable} params (no blocks)", flush=True)
+ opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.to(device); y = y.to(device)
+ loss = F.cross_entropy(model(x), y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ if ep % 10 == 0 or ep == epochs:
+ acc = evaluate(model, test_loader, device)
+ print(f" [Student-shallow] s={seed} ep {ep}: acc={acc:.4f}", flush=True)
+ return evaluate(model, test_loader, device)
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--output', type=str, default='results/frozen_baselines_crossarch.json')
+ args = p.parse_args()
+
+ device = torch.device('cuda:0')
+
+ results = {}
+
+ # ── ViT-Mini (CIFAR-10, 60 epochs) ──
+ print("\n=== ViT-Mini frozen baselines ===", flush=True)
+ train_loader, test_loader = get_cifar10(128)
+ for seed in [42, 123, 456]:
+ print(f"\n--- ViT-Mini seed={seed} ---", flush=True)
+ frozen_acc = train_vit_frozen(seed, train_loader, test_loader, device, 60, 1e-3, 0.05)
+ shallow_acc = train_vit_shallow(seed, train_loader, test_loader, device, 60, 1e-3, 0.05)
+ results[f'vit_frozen_s{seed}'] = frozen_acc
+ results[f'vit_shallow_s{seed}'] = shallow_acc
+ print(f" FINAL ViT s={seed}: frozen={frozen_acc:.4f}, shallow={shallow_acc:.4f}", flush=True)
+
+ # ── StudentNet (synthetic, 80 epochs) ──
+ print("\n=== StudentNet frozen baselines ===", flush=True)
+ L, d, C, alpha = 4, 128, 10, 1.0
+ for seed in [42, 123, 456]:
+ print(f"\n--- StudentNet seed={seed} ---", flush=True)
+ set_seed(seed)
+ teacher = TeacherNet(d, L, C, alpha, seed=0).to(device)
+ X_tr, Y_tr = generate_synth_dataset(teacher, 50*256, d, device, seed=seed)
+ X_te, Y_te = generate_synth_dataset(teacher, 2000, d, device, seed=seed+10000)
+ s_train = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=256, shuffle=True)
+ s_test = DataLoader(TensorDataset(X_te, Y_te), batch_size=256, shuffle=False)
+
+ frozen_acc = train_student_frozen(seed, s_train, s_test, device, 80, 1e-3, 0.01, alpha)
+ shallow_acc = train_student_shallow(seed, s_train, s_test, device, 80, 1e-3, 0.01, alpha)
+ results[f'student_frozen_s{seed}'] = frozen_acc
+ results[f'student_shallow_s{seed}'] = shallow_acc
+ print(f" FINAL Student s={seed}: frozen={frozen_acc:.4f}, shallow={shallow_acc:.4f}", flush=True)
+
+ with open(args.output, 'w') as f:
+ json.dump(results, f, indent=2)
+ print(f"\nSaved: {args.output}", flush=True)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/resnet_frozen_blocks_baseline.py b/experiments/resnet_frozen_blocks_baseline.py
new file mode 100644
index 0000000..787876d
--- /dev/null
+++ b/experiments/resnet_frozen_blocks_baseline.py
@@ -0,0 +1,278 @@
+"""
+Frozen-blocks and shallow baselines for a small CIFAR-10 ResNet (BatchNorm,
+no LayerNorm) — codex-round-10 control to test whether the DFA "active-harm"
+walk-back generalizes from LN-based architectures (ViT-Mini, ResMLP) to a
+BN-based residual architecture.
+
+Conditions per seed:
+ - BP shallow (num_blocks=0)
+ - BP frozen-blocks (num_blocks=4 frozen)
+ - BP trainable (num_blocks=4)
+ - DFA shallow (num_blocks=0)
+ - DFA frozen-blocks (num_blocks=4 frozen)
+ - DFA trainable (num_blocks=4)
+
+If DFA-trainable < DFA-shallow on ResNet too → claim becomes "FA fails to train
+deep blocks across multiple residual architectures including BN-based" — much
+harder to dismiss as LN-specific.
+If DFA-trainable ≈ or > DFA-shallow on ResNet → "harmful mode is specific to LN
+normalization or terminal-LN architectures" — narrower but still useful claim.
+
+Usage:
+ CUDA_VISIBLE_DEVICES=2 python experiments/resnet_frozen_blocks_baseline.py --seed 42
+"""
+import sys, os, argparse
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import numpy as np
+
+from models.small_resnet import SmallResNet
+
+
+def get_loaders(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (
+ DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2),
+ )
+
+
+def evaluate(model, loader, dev):
+ model.eval()
+ n = c = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x, y = x.to(dev), y.to(dev)
+ preds = model(x).argmax(-1)
+ c += (preds == y).sum().item()
+ n += x.size(0)
+ return c / n
+
+
+def freeze_blocks(model):
+ for p in model.blocks.parameters():
+ p.requires_grad_(False)
+ # Also keep BN running stats frozen by setting to eval()
+ for m in model.blocks.modules():
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
+ m.eval()
+
+
+def train_bp(model, train_loader, test_loader, dev, epochs, lr, wd, label, blocks_frozen=False):
+ opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ model.train()
+ if blocks_frozen:
+ for m in model.blocks.modules():
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
+ m.eval() # keep BN stats frozen
+ for x, y in train_loader:
+ x, y = x.to(dev), y.to(dev)
+ loss = F.cross_entropy(model(x), y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ if ep % 10 == 0 or ep == 1 or ep == epochs:
+ acc = evaluate(model, test_loader, dev)
+ print(f" [{label}] ep {ep}: test_acc={acc:.4f}", flush=True)
+ return model
+
+
+def train_dfa(model, train_loader, test_loader, dev, epochs, lr, wd, label, blocks_frozen=False):
+ """DFA on the BN-ResNet:
+ - head trained with true CE on the pooled hidden state
+ - stem (conv + BN) trained via DFA-style local loss with random feedback
+ - blocks (if any) skipped (frozen for blocks_frozen=True; for trainable case, the
+ naive analog would be DFA-style local loss per block, but this script focuses on
+ the frozen/shallow comparison; for trainable comparison use the existing ResMLP
+ experiment as the analogous "trainable" since they share the same ad-hoc DFA pattern).
+ For this experiment we focus on the frozen and shallow conditions.
+ """
+ d_hidden = model.d_hidden
+ L = max(model.num_blocks, 1)
+ C = 10
+ Bs = [torch.randn(d_hidden, C, device=dev) / np.sqrt(C) for _ in range(L)]
+
+ stem_params = list(model.stem_conv.parameters()) + list(model.stem_bn.parameters())
+ stem_opt = optim.AdamW(stem_params, lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
+ sch1 = optim.lr_scheduler.CosineAnnealingLR(stem_opt, T_max=epochs)
+ sch2 = optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)
+
+ for ep in range(1, epochs + 1):
+ model.train()
+ if blocks_frozen:
+ for m in model.blocks.modules():
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
+ m.eval()
+ for x, y in train_loader:
+ x, y = x.to(dev), y.to(dev)
+ with torch.no_grad():
+ logits, hi = model(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(x.size(0)), y] -= 1
+ hL_det = hi[-1].detach() # (B, d_hidden, 32, 32)
+ # Head update via true CE on pooled cls
+ h_pool = F.adaptive_avg_pool2d(hL_det, 1).flatten(1)
+ head_opt.zero_grad()
+ F.cross_entropy(model.out_head(h_pool), y).backward()
+ head_opt.step()
+ # Stem update via DFA local loss
+ a0 = (e_T @ Bs[0].T).detach() # (B, d_hidden)
+ rms = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.stem(x) # (B, d_hidden, 32, 32)
+ # Broadcast credit across spatial positions: (B, d, 1, 1) -> (B, d, H, W)
+ a0_b = (a0 / rms).unsqueeze(-1).unsqueeze(-1).expand_as(h0)
+ stem_loss = (h0 * a0_b).sum(dim=1).mean() # average over batch and spatial
+ stem_opt.zero_grad()
+ stem_loss.backward()
+ stem_opt.step()
+ sch1.step(); sch2.step()
+ if ep % 10 == 0 or ep == 1 or ep == epochs:
+ acc = evaluate(model, test_loader, dev)
+ print(f" [{label}] ep {ep}: test_acc={acc:.4f}", flush=True)
+ return model
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--epochs', type=int, default=60)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--wd', type=float, default=0.01)
+ parser.add_argument('--d_hidden', type=int, default=64)
+ args = parser.parse_args()
+
+ dev = torch.device('cuda:0')
+ print(f"Device: {dev}, seed={args.seed}, epochs={args.epochs}", flush=True)
+ train_loader, test_loader = get_loaders(batch_size=128)
+
+ results = {}
+ C = 10
+
+ # Trainable BP (full 4-block ResNet)
+ print(f"\n=== BP trainable (SmallResNet num_blocks=4), seed={args.seed} ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=4).to(dev)
+ print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True)
+ train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-trainable')
+ results['bp_trainable'] = evaluate(m, test_loader, dev)
+ print(f"FINAL BP-trainable: {results['bp_trainable']:.4f}", flush=True)
+
+ # Trainable DFA — block-level DFA on ResNet (each block as a unit)
+ print(f"\n=== DFA trainable (SmallResNet num_blocks=4 block-level DFA), seed={args.seed} ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=4).to(dev)
+ # We use the same approach as ViT/ResMLP: stem trained with DFA, blocks trained
+ # with their own DFA-style local loss per block, head with true CE.
+ # For simplicity reuse train_dfa logic but extend it to also train blocks.
+ # Since this script focuses on frozen/shallow control, we'll do trainable in a
+ # separate inner loop here.
+ d_hidden = m.d_hidden; L = m.num_blocks
+ Bs = [torch.randn(d_hidden, C, device=dev) / np.sqrt(C) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=args.lr, weight_decay=args.wd) for b in m.blocks]
+ stem_params = list(m.stem_conv.parameters()) + list(m.stem_bn.parameters())
+ stem_opt = optim.AdamW(stem_params, lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(m.out_head.parameters(), lr=args.lr, weight_decay=args.wd)
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(stem_opt, T_max=args.epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]
+ for ep in range(1, args.epochs + 1):
+ m.train()
+ for x, y in train_loader:
+ x, y = x.to(dev), y.to(dev)
+ with torch.no_grad():
+ logits, hi = m(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(x.size(0)), y] -= 1
+ hL_det = hi[-1].detach()
+ h_pool = F.adaptive_avg_pool2d(hL_det, 1).flatten(1)
+ head_opt.zero_grad()
+ F.cross_entropy(m.out_head(h_pool), y).backward()
+ head_opt.step()
+ for l in range(L):
+ h_l = hi[l].detach()
+ a_l = (e_T @ Bs[l].T).detach()
+ rms = (a_l ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_l_norm = (a_l / rms).unsqueeze(-1).unsqueeze(-1).expand_as(h_l)
+ f_l = m.blocks[l](h_l)
+ local_loss = (f_l * a_l_norm).sum(dim=1).mean()
+ block_opts[l].zero_grad(); local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(m.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+ a_0 = (e_T @ Bs[0].T).detach()
+ rms_0 = (a_0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ h0 = m.stem(x)
+ a_0_b = (a_0 / rms_0).unsqueeze(-1).unsqueeze(-1).expand_as(h0)
+ stem_loss = (h0 * a_0_b).sum(dim=1).mean()
+ stem_opt.zero_grad(); stem_loss.backward(); stem_opt.step()
+ for s in all_sch: s.step()
+ if ep % 10 == 0 or ep == 1 or ep == args.epochs:
+ acc = evaluate(m, test_loader, dev)
+ print(f" [DFA-trainable] ep {ep}: test_acc={acc:.4f}", flush=True)
+ results['dfa_trainable'] = evaluate(m, test_loader, dev)
+ print(f"FINAL DFA-trainable: {results['dfa_trainable']:.4f}", flush=True)
+
+ # BP shallow
+ print(f"\n=== BP shallow (SmallResNet num_blocks=0), seed={args.seed} ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=0).to(dev)
+ print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True)
+ train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-shallow')
+ results['bp_shallow'] = evaluate(m, test_loader, dev)
+ print(f"FINAL BP-shallow: {results['bp_shallow']:.4f}", flush=True)
+
+ # BP frozen-blocks
+ print(f"\n=== BP frozen-blocks (SmallResNet num_blocks=4 frozen), seed={args.seed} ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=4).to(dev)
+ freeze_blocks(m)
+ print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True)
+ train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-frozen', blocks_frozen=True)
+ results['bp_frozen'] = evaluate(m, test_loader, dev)
+ print(f"FINAL BP-frozen-blocks: {results['bp_frozen']:.4f}", flush=True)
+
+ # DFA shallow
+ print(f"\n=== DFA shallow (SmallResNet num_blocks=0), seed={args.seed} ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=0).to(dev)
+ train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-shallow')
+ results['dfa_shallow'] = evaluate(m, test_loader, dev)
+ print(f"FINAL DFA-shallow: {results['dfa_shallow']:.4f}", flush=True)
+
+ # DFA frozen-blocks
+ print(f"\n=== DFA frozen-blocks (SmallResNet num_blocks=4 frozen), seed={args.seed} ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=4).to(dev)
+ freeze_blocks(m)
+ train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-frozen', blocks_frozen=True)
+ results['dfa_frozen'] = evaluate(m, test_loader, dev)
+ print(f"FINAL DFA-frozen-blocks: {results['dfa_frozen']:.4f}", flush=True)
+
+ print(f"\n=== Small ResNet (BatchNorm) frozen/shallow baseline summary, seed={args.seed} ===")
+ for k, v in results.items():
+ print(f" {k}: {v:.4f}")
+ print(f"\nKey gaps (DFA):")
+ if 'dfa_shallow' in results and 'dfa_trainable' in results:
+ print(f" DFA-shallow ({results['dfa_shallow']:.4f}) - DFA-trainable ({results['dfa_trainable']:.4f}) = {results['dfa_shallow']-results['dfa_trainable']:+.4f}")
+ if 'dfa_frozen' in results and 'dfa_trainable' in results:
+ print(f" DFA-frozen ({results['dfa_frozen']:.4f}) - DFA-trainable ({results['dfa_trainable']:.4f}) = {results['dfa_frozen']-results['dfa_trainable']:+.4f}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/resnet_protocol_validation.py b/experiments/resnet_protocol_validation.py
new file mode 100644
index 0000000..f107231
--- /dev/null
+++ b/experiments/resnet_protocol_validation.py
@@ -0,0 +1,343 @@
+"""
+Protocol validation on SmallResNet (BatchNorm, no LN) — BP/FA/DFA + frozen baseline.
+Block-level DFA/FA: credit broadcast across spatial positions, same local loss as ResMLP.
+"""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision, torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.small_resnet import SmallResNet
+from metrics.credit_metrics import cosine_similarity_batch
+
+
+def get_data(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2))
+
+
+def evaluate(model, loader, dev):
+ model.eval()
+ c = n = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x, y = x.to(dev), y.to(dev)
+ c += (model(x).argmax(-1) == y).sum().item()
+ n += x.size(0)
+ return c / n
+
+
+def compute_diagnostics(model, x_eval, y_eval, device, method_name, dfa_Bs=None, fa_Bs=None):
+ """Compute per-layer cosine, ||g_l||, ||h_l|| for SmallResNet."""
+ model.eval()
+ L = model.num_blocks
+ C = 10
+
+ # Hidden states
+ with torch.no_grad():
+ _, hiddens = model(x_eval, return_hidden=True)
+
+ # For ||h||: pool each hidden to (B, d) then take norm
+ hidden_norms = []
+ for h in hiddens:
+ h_pool = F.adaptive_avg_pool2d(h, 1).flatten(1) # (B, d)
+ hidden_norms.append(float(h_pool.norm(dim=-1).median().item()))
+
+ # BP grads via manual forward
+ h = model.stem(x_eval)
+ hs = [h.clone().requires_grad_(True)]
+ for block in model.blocks:
+ # Need to handle BN eval mode for frozen
+ hs.append(block(hs[-1]))
+ h_pool = F.adaptive_avg_pool2d(hs[-1], 1).flatten(1)
+ logits = model.out_head(h_pool)
+ loss = F.cross_entropy(logits, y_eval)
+ grads = torch.autograd.grad(loss, hs)
+
+ # ||g_l|| using pooled gradient
+ bp_grad_norms = []
+ for g in grads:
+ g_pool = F.adaptive_avg_pool2d(g, 1).flatten(1) # (B, d)
+ bp_grad_norms.append(float(g_pool.norm(dim=-1).median().item()))
+
+ # Per-layer cosine
+ with torch.no_grad():
+ e_T = logits.softmax(-1)
+ e_T[torch.arange(x_eval.size(0)), y_eval] -= 1
+
+ bp_cosine = []
+ d = model.d_hidden
+
+ if method_name == 'fa' and fa_Bs is not None:
+ # FA: sequential backward from exact pooled gradient
+ hL_pool_req = F.adaptive_avg_pool2d(hiddens[-1].detach(), 1).flatten(1).requires_grad_(True)
+ logits_fa = model.out_head(hL_pool_req)
+ loss_fa = F.cross_entropy(logits_fa, y_eval)
+ a_credit = torch.autograd.grad(loss_fa, hL_pool_req)[0].detach()
+
+ for l in range(L - 1, -1, -1):
+ # Compare pooled credit with pooled BP grad
+ g_pool = F.adaptive_avg_pool2d(grads[l], 1).flatten(1).detach()
+ bp_cosine.insert(0, cosine_similarity_batch(a_credit, g_pool))
+ a_credit = (a_credit @ fa_Bs[l]).detach()
+
+ elif method_name == 'dfa' and dfa_Bs is not None:
+ for l in range(L):
+ a_dfa = (e_T @ dfa_Bs[l].T).detach() # (B, d)
+ g_pool = F.adaptive_avg_pool2d(grads[l], 1).flatten(1).detach()
+ bp_cosine.append(cosine_similarity_batch(a_dfa, g_pool))
+
+ elif method_name == 'bp':
+ bp_cosine = [1.0] * L
+
+ model.train()
+ return {
+ 'bp_cosine': bp_cosine,
+ 'bp_grad_norms_per_layer': bp_grad_norms,
+ 'hidden_norms_per_layer': hidden_norms,
+ }
+
+
+def train_bp(model, train_loader, test_loader, dev, epochs, lr, wd):
+ opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+ for ep in range(1, epochs + 1):
+ model.train()
+ tl, tc, tn = 0, 0, 0
+ for x, y in train_loader:
+ x, y = x.to(dev), y.to(dev)
+ logits = model(x)
+ loss = F.cross_entropy(logits, y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ tl += loss.item() * x.size(0)
+ tc += (logits.argmax(1) == y).sum().item()
+ tn += x.size(0)
+ sch.step()
+ log['train_loss'].append(tl / tn)
+ log['train_acc'].append(tc / tn)
+ log['test_acc'].append(evaluate(model, test_loader, dev))
+ if ep % 10 == 0 or ep == epochs:
+ print(f" [BP] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True)
+ return log
+
+
+def train_dfa(model, train_loader, test_loader, dev, epochs, lr, wd):
+ d = model.d_hidden
+ L = model.num_blocks
+ C = 10
+ Bs = [torch.randn(d, C, device=dev) / np.sqrt(C) for _ in range(L)]
+ block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks]
+ stem_opt = optim.AdamW(list(model.stem_conv.parameters()) + list(model.stem_bn.parameters()),
+ lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(stem_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+ for ep in range(1, epochs + 1):
+ model.train()
+ tl, tc, tn = 0, 0, 0
+ for x, y in train_loader:
+ x, y = x.to(dev), y.to(dev)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1
+ # Head
+ hL_pool = F.adaptive_avg_pool2d(hiddens[-1].detach(), 1).flatten(1)
+ head_opt.zero_grad()
+ F.cross_entropy(model.out_head(hL_pool), y).backward()
+ head_opt.step()
+ # Blocks
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a_dfa = (e_T @ Bs[l].T).detach() # (B, d)
+ rms = (a_dfa ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_norm = (a_dfa / rms).unsqueeze(-1).unsqueeze(-1).expand_as(h_l)
+ f_l = model.blocks[l](h_l) - h_l # residual output only
+ local_loss = (f_l * a_norm).sum(dim=1).mean()
+ block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step()
+ # Stem
+ a0 = (e_T @ Bs[0].T).detach()
+ rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.stem(x)
+ a0_b = (a0 / rms0).unsqueeze(-1).unsqueeze(-1).expand_as(h0)
+ stem_opt.zero_grad()
+ (h0 * a0_b).sum(dim=1).mean().backward()
+ stem_opt.step()
+ for s in all_sch: s.step()
+ tl += loss_val.item() * batch; tc += (logits.argmax(1) == y).sum().item(); tn += batch
+ log['train_loss'].append(tl / tn); log['train_acc'].append(tc / tn)
+ log['test_acc'].append(evaluate(model, test_loader, dev))
+ if ep % 10 == 0 or ep == epochs:
+ print(f" [DFA] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True)
+ return log, Bs
+
+
+def train_fa(model, train_loader, test_loader, dev, epochs, lr, wd):
+ d = model.d_hidden
+ L = model.num_blocks
+ Bs = [torch.randn(d, d, device=dev) / np.sqrt(d) for _ in range(L)]
+ block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks]
+ stem_opt = optim.AdamW(list(model.stem_conv.parameters()) + list(model.stem_bn.parameters()),
+ lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(stem_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+ for ep in range(1, epochs + 1):
+ model.train()
+ tl, tc, tn = 0, 0, 0
+ for x, y in train_loader:
+ x, y = x.to(dev), y.to(dev)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ # Head — get gradient BEFORE step
+ hL_pool = F.adaptive_avg_pool2d(hiddens[-1].detach(), 1).flatten(1).requires_grad_(True)
+ logits_out = model.out_head(hL_pool)
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ a_credit = hL_pool.grad.detach() # (B, d) — pooled gradient
+ head_opt.step()
+ # Top-down block updates with FA credit
+ for l in range(L - 1, -1, -1):
+ h_l = hiddens[l].detach()
+ rms = (a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_norm = (a_credit / rms).unsqueeze(-1).unsqueeze(-1).expand_as(h_l)
+ f_l = model.blocks[l](h_l) - h_l
+ local_loss = (f_l * a_norm).sum(dim=1).mean()
+ block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step()
+ a_credit = (a_credit @ Bs[l]).detach()
+ # Stem
+ rms0 = (a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.stem(x)
+ a0_b = (a_credit / rms0).unsqueeze(-1).unsqueeze(-1).expand_as(h0)
+ stem_opt.zero_grad()
+ (h0 * a0_b).sum(dim=1).mean().backward()
+ stem_opt.step()
+ for s in all_sch: s.step()
+ tl += loss_val.item() * batch; tc += (logits.argmax(1) == y).sum().item(); tn += batch
+ log['train_loss'].append(tl / tn); log['train_acc'].append(tc / tn)
+ log['test_acc'].append(evaluate(model, test_loader, dev))
+ if ep % 10 == 0 or ep == epochs:
+ print(f" [FA] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True)
+ return log, Bs
+
+
+def freeze_blocks(model):
+ for p in model.blocks.parameters():
+ p.requires_grad_(False)
+ for m in model.blocks.modules():
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
+ m.eval()
+
+
+def train_frozen(model, train_loader, test_loader, dev, epochs, lr, wd):
+ opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for m in model.blocks.modules():
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
+ m.eval()
+ for x, y in train_loader:
+ x, y = x.to(dev), y.to(dev)
+ loss = F.cross_entropy(model(x), y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ if ep % 10 == 0 or ep == epochs:
+ acc = evaluate(model, test_loader, dev)
+ print(f" [Frozen] ep {ep}: acc={acc:.4f}", flush=True)
+ return evaluate(model, test_loader, dev)
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--output', type=str, default='results/resnet_protocol_validation.json')
+ p.add_argument('--epochs', type=int, default=100)
+ p.add_argument('--d_hidden', type=int, default=64)
+ args = p.parse_args()
+
+ dev = torch.device('cuda:0')
+ train_loader, test_loader = get_data(128)
+
+ # Eval buffer for diagnostics (128 samples, consistent with cifar_resmlp.py)
+ xs, ys = [], []
+ for x, y in test_loader:
+ xs.append(x); ys.append(y)
+ if sum(xb.size(0) for xb in xs) >= 128: break
+ x_eval = torch.cat(xs)[:128].to(dev)
+ y_eval = torch.cat(ys)[:128].to(dev)
+
+ results = {}
+
+ for seed in [42, 123, 456]:
+ print(f"\n{'='*60}\nSeed {seed}\n{'='*60}", flush=True)
+ seed_results = {}
+
+ # BP
+ print("\n--- BP ---", flush=True)
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ model = SmallResNet(args.d_hidden, 10, 4).to(dev)
+ bp_log = train_bp(model, train_loader, test_loader, dev, args.epochs, 1e-3, 0.01)
+ bp_diag = compute_diagnostics(model, x_eval, y_eval, dev, 'bp')
+ seed_results['bp'] = {'log': bp_log, 'diagnostics': bp_diag}
+
+ # FA
+ print("\n--- FA ---", flush=True)
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ model = SmallResNet(args.d_hidden, 10, 4).to(dev)
+ fa_log, fa_Bs = train_fa(model, train_loader, test_loader, dev, args.epochs, 1e-3, 0.01)
+ fa_diag = compute_diagnostics(model, x_eval, y_eval, dev, 'fa', fa_Bs=fa_Bs)
+ seed_results['fa'] = {'log': fa_log, 'diagnostics': fa_diag}
+
+ # DFA
+ print("\n--- DFA ---", flush=True)
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ model = SmallResNet(args.d_hidden, 10, 4).to(dev)
+ dfa_log, dfa_Bs = train_dfa(model, train_loader, test_loader, dev, args.epochs, 1e-3, 0.01)
+ dfa_diag = compute_diagnostics(model, x_eval, y_eval, dev, 'dfa', dfa_Bs=dfa_Bs)
+ seed_results['dfa'] = {'log': dfa_log, 'diagnostics': dfa_diag}
+
+ # Frozen baseline
+ print("\n--- Frozen ---", flush=True)
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ model = SmallResNet(args.d_hidden, 10, 4).to(dev)
+ freeze_blocks(model)
+ frozen_acc = train_frozen(model, train_loader, test_loader, dev, args.epochs, 1e-3, 0.01)
+ seed_results['frozen_acc'] = frozen_acc
+ print(f"FINAL frozen: {frozen_acc:.4f}", flush=True)
+
+ results[str(seed)] = seed_results
+
+ with open(args.output, 'w') as f:
+ json.dump(results, f, indent=2)
+ print(f"\nSaved: {args.output}", flush=True)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/snapshot_compare_outln.py b/experiments/snapshot_compare_outln.py
new file mode 100644
index 0000000..9b0ac7c
--- /dev/null
+++ b/experiments/snapshot_compare_outln.py
@@ -0,0 +1,93 @@
+"""
+Compare snapshot evolution JSONs across with-out_ln vs no-out_ln conditions
+and across seeds. Produces summary tables for the P4 figure.
+
+Usage:
+ python experiments/snapshot_compare_outln.py
+"""
+import os, sys, json, glob
+import numpy as np
+
+
+def load(path):
+ if not os.path.exists(path):
+ return None
+ with open(path) as f:
+ return json.load(f)
+
+
+def field(d, key, layer=2):
+ """Extract a per-epoch list of values for a given metric/layer."""
+ if d is None:
+ return None
+ log = d['bp_log'] if 'bp' in key else d['dfa_log'] if 'dfa' in key else None
+ metric = key.replace('bp_', '').replace('dfa_', '')
+ if metric == 'h_L_norm':
+ return [r['hidden_norms'][-1] for r in log]
+ if metric == 'h_L2_norm':
+ return [r['hidden_norms'][2] if len(r['hidden_norms']) > 2 else None for r in log]
+ if metric == 'g_l2':
+ key_in_log = 'bp_grad_per_sample_l2_med' if 'bp_grad_per_sample_l2_med' in log[0] else 'bp_grad_norms_per_sample_med'
+ return [r[key_in_log][layer] for r in log]
+ if metric == 'acc':
+ return [r['acc_eval'] for r in log]
+ if metric == 'gamma_dfa':
+ return [r.get('gamma_dfa', float('nan')) for r in log]
+ return None
+
+
+def summary_row(d, label):
+ """Print a summary row for the comparison table."""
+ if d is None:
+ print(f"{label:35s} MISSING")
+ return
+ bp = d['bp_log']
+ dfa = d['dfa_log']
+ bp_eps = [r['epoch'] for r in bp]
+ dfa_eps = [r['epoch'] for r in dfa]
+ bp_h_L_init = bp[0]['hidden_norms'][-1]
+ bp_h_L_final = bp[-1]['hidden_norms'][-1]
+ dfa_h_L_init = dfa[0]['hidden_norms'][-1]
+ dfa_h_L_final = dfa[-1]['hidden_norms'][-1]
+ bp_g_key = 'bp_grad_per_sample_l2_med' if 'bp_grad_per_sample_l2_med' in bp[0] else 'bp_grad_norms_per_sample_med'
+ bp_g2_init = bp[0][bp_g_key][2]
+ bp_g2_final = bp[-1][bp_g_key][2]
+ dfa_g2_init = dfa[0][bp_g_key][2]
+ dfa_g2_final = dfa[-1][bp_g_key][2]
+ bp_acc = bp[-1]['acc_eval']
+ dfa_acc = dfa[-1]['acc_eval']
+ bp_growth = bp_h_L_final / max(bp_h_L_init, 1e-12)
+ dfa_growth = dfa_h_L_final / max(dfa_h_L_init, 1e-12)
+ bp_g_change = bp_g2_final / max(bp_g2_init, 1e-30)
+ dfa_g_change = dfa_g2_final / max(dfa_g2_init, 1e-30)
+
+ print(f"{label:35s} BP_acc={bp_acc:.3f} DFA_acc={dfa_acc:.3f} "
+ f"BP_||h_L||: {bp_h_L_init:.1e}→{bp_h_L_final:.1e} (×{bp_growth:.1e}) "
+ f"DFA_||h_L||: {dfa_h_L_init:.1e}→{dfa_h_L_final:.1e} (×{dfa_growth:.1e}) "
+ f"BP_||g_2||: {bp_g2_init:.1e}→{bp_g2_final:.1e} "
+ f"DFA_||g_2||: {dfa_g2_init:.1e}→{dfa_g2_final:.1e}")
+
+
+def main():
+ print("=" * 130)
+ print("SNAPSHOT EVOLUTION COMPARISON: with-out_ln vs no-out_ln vs synthetic")
+ print("=" * 130)
+
+ runs = [
+ ('with-out_ln s42 (ResMLP CIFAR)', 'results/snapshot_evolution_v2/snapshot_evolution_s42.json'),
+ ('no-out_ln s42 (ResMLP CIFAR)', 'results/snapshot_no_outln_v1/snapshot_noLN_s42.json'),
+ ('no-out_ln s123 (ResMLP CIFAR)', 'results/snapshot_no_outln_v1/snapshot_noLN_s123.json'),
+ ('no-out_ln s456 (ResMLP CIFAR)', 'results/snapshot_no_outln_v1/snapshot_noLN_s456.json'),
+ ('synthetic α=1 s42 (StudentNet)', 'results/snapshot_synth_v1/snapshot_synth_a1.0_L4_s42.json'),
+ ]
+ for label, path in runs:
+ d = load(path)
+ summary_row(d, label)
+
+ print()
+ print("Legend: ||h_L|| = median per-sample L2 norm of final hidden state; ||g_2|| = median per-sample L2 norm of BP gradient at h_2.")
+ print("All norms use .norm(dim=-1), correct.")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/snapshot_evolution_no_outln.py b/experiments/snapshot_evolution_no_outln.py
new file mode 100644
index 0000000..312a4cb
--- /dev/null
+++ b/experiments/snapshot_evolution_no_outln.py
@@ -0,0 +1,249 @@
+"""
+Snapshot evolution on a NO-out_ln variant of the standard ResidualMLP.
+Same architecture as ResidualMLP but with the terminal LayerNorm removed
+(head reads h_L directly). Trains BP and DFA from scratch on CIFAR-10 and
+logs ||h_l||_2 + ||BP grad||_2 per epoch.
+
+This is the architectural causal control for P4: if removing out_ln from the
+SAME architecture rescues the residual-stream pathology, then out_ln is
+causally responsible (not just correlated).
+
+Usage:
+ CUDA_VISIBLE_DEVICES=2 nohup python experiments/snapshot_evolution_no_outln.py \
+ --output_dir results/snapshot_no_outln_v1 --epochs 100 --seed 42 \
+ > results/snapshot_no_outln_v1/run_s42.log 2>&1 &
+"""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from metrics.credit_metrics import cosine_similarity_batch
+
+
+class ResidualBlockPreLN(nn.Module):
+ """Same as models/residual_mlp.ResidualBlock — pre-LN MLP block."""
+ def __init__(self, d_hidden: int):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.w1 = nn.Linear(d_hidden, d_hidden)
+ self.w2 = nn.Linear(d_hidden, d_hidden)
+ nn.init.normal_(self.w2.weight, std=0.01)
+ nn.init.zeros_(self.w2.bias)
+ def forward(self, h):
+ z = self.ln(h)
+ z = self.w1(z)
+ z = F.gelu(z)
+ z = self.w2(z)
+ return z
+
+
+class ResidualMLP_NoOutLN(nn.Module):
+ """Like ResidualMLP, but WITHOUT out_ln. Head reads h_L directly."""
+ def __init__(self, input_dim, d_hidden, num_classes, num_blocks):
+ super().__init__()
+ self.embed = nn.Linear(input_dim, d_hidden)
+ self.blocks = nn.ModuleList([ResidualBlockPreLN(d_hidden) for _ in range(num_blocks)])
+ # NO out_ln
+ self.out_head = nn.Linear(d_hidden, num_classes)
+ self.num_blocks = num_blocks
+ self.d_hidden = d_hidden
+
+ def forward(self, x, return_hidden=False):
+ h = self.embed(x)
+ hiddens = [h] if return_hidden else None
+ for block in self.blocks:
+ f = block(h)
+ h = h + f
+ if return_hidden:
+ hiddens.append(h)
+ logits = self.out_head(h) # NO out_ln
+ if return_hidden:
+ return logits, hiddens
+ return logits
+
+
+def get_cifar10(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2))
+
+
+def fixed_eval_buffer(test_loader, device, n_samples=1024):
+ xs, ys = [], []
+ for x, y in test_loader:
+ xs.append(x.view(x.size(0), -1)); ys.append(y)
+ if sum(xb.size(0) for xb in xs) >= n_samples:
+ break
+ return torch.cat(xs)[:n_samples].to(device), torch.cat(ys)[:n_samples].to(device)
+
+
+def diagnose(model, x_eval, y_eval, dfa_Bs=None):
+ was_training = model.training
+ model.eval()
+ L = model.num_blocks
+ with torch.no_grad():
+ _, hi = model(x_eval, return_hidden=True)
+ hidden_norms = [h.norm(dim=-1).median().item() for h in hi]
+
+ h0 = model.embed(x_eval.detach())
+ hs = [h0.clone().requires_grad_(True)]
+ for b in model.blocks:
+ hs.append(hs[-1] + b(hs[-1]))
+ logits = model.out_head(hs[-1]) # NO out_ln
+ loss = F.cross_entropy(logits, y_eval)
+ grads = torch.autograd.grad(loss, hs)
+ bp_l2 = [g.norm(dim=-1).median().item() for g in grads]
+ bp_full = [g.detach() for g in grads]
+ acc = (logits.argmax(-1) == y_eval).float().mean().item()
+ loss_val = loss.item()
+
+ gamma_dfa = float('nan'); per_layer_gamma = []
+ if dfa_Bs is not None:
+ with torch.no_grad():
+ e_T = logits.softmax(-1); e_T[torch.arange(x_eval.size(0)), y_eval] -= 1
+ for l in range(L):
+ a_dfa = (e_T @ dfa_Bs[l].T).detach()
+ per_layer_gamma.append(cosine_similarity_batch(a_dfa, bp_full[l]))
+ gamma_dfa = float(np.mean(per_layer_gamma))
+
+ if was_training: model.train()
+ return {
+ 'hidden_norms': hidden_norms,
+ 'bp_grad_per_sample_l2_med': bp_l2,
+ 'gamma_dfa': gamma_dfa,
+ 'gamma_dfa_per_layer': per_layer_gamma,
+ 'acc_eval': acc, 'loss_eval': loss_val,
+ }
+
+
+def train_bp(model, train_loader, x_eval, y_eval, device, epochs, lr, wd):
+ opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ log = []
+ d0 = diagnose(model, x_eval, y_eval); d0['epoch'] = 0; log.append(d0)
+ print(f" [BP-noLN] Ep 0: ||h_L||={d0['hidden_norms'][-1]:.3e} ||g||={d0['bp_grad_per_sample_l2_med'][2]:.3e} acc={d0['acc_eval']:.4f}", flush=True)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ logits = model(x); loss = F.cross_entropy(logits, y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ d = diagnose(model, x_eval, y_eval); d['epoch'] = ep; log.append(d)
+ if ep % 5 == 0 or ep == 1 or ep == epochs:
+ print(f" [BP-noLN] Ep {ep}: ||h_L||={d['hidden_norms'][-1]:.3e} ||g||={d['bp_grad_per_sample_l2_med'][2]:.3e} acc={d['acc_eval']:.4f}", flush=True)
+ return log
+
+
+def train_dfa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd):
+ d_hidden = model.d_hidden; L = model.num_blocks; C = 10
+ Bs = [torch.randn(d_hidden, C, device=device) / np.sqrt(C) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+ log = []
+ d0 = diagnose(model, x_eval, y_eval, dfa_Bs=Bs); d0['epoch'] = 0; log.append(d0)
+ print(f" [DFA-noLN] Ep 0: ||h_L||={d0['hidden_norms'][-1]:.3e} ||g||={d0['bp_grad_per_sample_l2_med'][2]:.3e} acc={d0['acc_eval']:.4f}", flush=True)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1
+ hL_det = hiddens[-1].detach()
+ # Head update — NO out_ln
+ logits_out = model.out_head(hL_det)
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad(); loss_out.backward(); head_opt.step()
+ # Block updates
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a_dfa = (e_T @ Bs[l].T).detach()
+ rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a_dfa / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad(); local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+ # Embed update
+ a_0 = (e_T @ Bs[0].T).detach()
+ rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+ for s in all_sch: s.step()
+ d = diagnose(model, x_eval, y_eval, dfa_Bs=Bs); d['epoch'] = ep; log.append(d)
+ if ep % 5 == 0 or ep == 1 or ep == epochs:
+ print(f" [DFA-noLN] Ep {ep}: ||h_L||={d['hidden_norms'][-1]:.3e} ||g||={d['bp_grad_per_sample_l2_med'][2]:.3e} acc={d['acc_eval']:.4f} γ={d['gamma_dfa']:.4f}", flush=True)
+ return log
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--output_dir', type=str, default='results/snapshot_no_outln_v1')
+ p.add_argument('--epochs', type=int, default=100)
+ p.add_argument('--lr', type=float, default=1e-3)
+ p.add_argument('--wd', type=float, default=0.01)
+ p.add_argument('--seed', type=int, default=42)
+ p.add_argument('--depth', type=int, default=4)
+ p.add_argument('--d_hidden', type=int, default=256)
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ device = torch.device('cuda:0')
+ print(f"NO-OUT_LN VARIANT: depth={args.depth}, d_hidden={args.d_hidden}, "
+ f"epochs={args.epochs}, seed={args.seed}", flush=True)
+
+ train_loader, test_loader = get_cifar10(batch_size=128)
+ x_eval, y_eval = fixed_eval_buffer(test_loader, device, n_samples=1024)
+
+ L, d, C = args.depth, args.d_hidden, 10
+
+ print("\n=== BP training (NO out_ln) ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ bp_model = ResidualMLP_NoOutLN(3072, d, C, L).to(device)
+ bp_log = train_bp(bp_model, train_loader, x_eval, y_eval, device, args.epochs, args.lr, args.wd)
+
+ print("\n=== DFA training (NO out_ln) ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ dfa_model = ResidualMLP_NoOutLN(3072, d, C, L).to(device)
+ dfa_log = train_dfa(dfa_model, train_loader, x_eval, y_eval, device, args.epochs, args.lr, args.wd)
+
+ out = {
+ 'config': vars(args), 'depth': L, 'd_hidden': d, 'num_classes': C,
+ 'architecture': 'ResidualMLP_NoOutLN',
+ 'bp_log': bp_log, 'dfa_log': dfa_log,
+ }
+ out_path = os.path.join(args.output_dir, f'snapshot_noLN_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(out, f, indent=2)
+ print(f"\nSaved {out_path}", flush=True)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/snapshot_evolution_residual_explosion.py b/experiments/snapshot_evolution_residual_explosion.py
index 1dc09f2..6155d94 100644
--- a/experiments/snapshot_evolution_residual_explosion.py
+++ b/experiments/snapshot_evolution_residual_explosion.py
@@ -212,6 +212,73 @@ def train_dfa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_e
return log
+def train_fa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_every=1):
+ """FA (Lillicrap 2016): sequential backward credit with d×d random matrices.
+ Canonical implementation matching cifar_resmlp.py train_fa():
+ - mean reduction (default)
+ - gradient taken BEFORE head step (old head weights)
+ - top-down block update, credit propagated after each block
+ - NO grad clipping
+ """
+ d_hidden = model.d_hidden
+ L = model.num_blocks
+ Bs_fa = [torch.randn(d_hidden, d_hidden, device=device) / np.sqrt(d_hidden) for _ in range(L)]
+ block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=lr, weight_decay=wd)
+ all_sch = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)])
+ log = []
+ d0 = diagnose(model, x_eval, y_eval)
+ d0['epoch'] = 0
+ log.append(d0)
+ print(f" [FA] Ep 0: ||h||_med={d0['hidden_norms']} acc={d0['acc_eval']:.4f}", flush=True)
+ for epoch in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ # Forward
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ # Head update — get gradient BEFORE step (old head weights)
+ hL_det = hiddens[-1].detach().requires_grad_(True)
+ logits_out = model.out_head(model.out_ln(hL_det))
+ loss_out = F.cross_entropy(logits_out, y) # mean reduction
+ head_opt.zero_grad()
+ loss_out.backward()
+ a_credit = hL_det.grad.detach() # gradient w.r.t. old head
+ head_opt.step()
+ # Top-down block updates with sequential FA credit propagation
+ for l in range(L - 1, -1, -1):
+ h_l = hiddens[l].detach()
+ rms = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a_credit / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ block_opts[l].step() # no grad clipping
+ a_credit = (a_credit @ Bs_fa[l]).detach()
+ # Embed update with final propagated credit
+ rms_0 = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ embed_loss = (h0 * (a_credit / rms_0)).sum(dim=-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+ for s in all_sch:
+ s.step()
+ if epoch % log_every == 0 or epoch == epochs:
+ d = diagnose(model, x_eval, y_eval)
+ d['epoch'] = epoch
+ log.append(d)
+ print(f" [FA] Ep {epoch}: ||h_L||={d['hidden_norms'][-1]:.3e} "
+ f"||g_2||={d['bp_grad_norms_per_sample_med'][2]:.3e} "
+ f"acc={d['acc_eval']:.4f}", flush=True)
+ return log
+
+
def main():
p = argparse.ArgumentParser()
p.add_argument('--output_dir', type=str, default='results/snapshot_evolution_v2')
@@ -262,11 +329,22 @@ def main():
args.epochs, args.lr, args.wd, log_every=args.log_every,
random_targets=args.random_targets)
+ fa_log = None
+ if not args.skip_bp and not args.random_targets: # FA only when doing full run
+ print("\n=== FA training ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ fa_model = ResidualMLP(3072, d, C, L,
+ residual_add=not args.no_residual_add,
+ w2_std=args.w2_std).to(device)
+ fa_log = train_fa(fa_model, train_loader, x_eval, y_eval, device,
+ args.epochs, args.lr, args.wd, log_every=args.log_every)
+
out = {
'config': vars(args),
'depth': L, 'd_hidden': d, 'num_classes': C,
'bp_log': bp_log,
'dfa_log': dfa_log,
+ 'fa_log': fa_log,
}
out_path = os.path.join(args.output_dir, f'snapshot_evolution_s{args.seed}.json')
with open(out_path, 'w') as f:
diff --git a/experiments/snapshot_evolution_vit.py b/experiments/snapshot_evolution_vit.py
new file mode 100644
index 0000000..ce4c090
--- /dev/null
+++ b/experiments/snapshot_evolution_vit.py
@@ -0,0 +1,244 @@
+"""
+Snapshot evolution on a ViT-Mini (modern transformer-style architecture) trained
+with BP and block-level DFA on CIFAR-10. Logs ||h_l||, ||BP grad||, Γ per epoch.
+
+This is the P4 generalization test: does the residual-stream pathology + LayerNorm
+gradient collapse mechanism (verified on pre-LN ResMLP with terminal LN) also
+appear on an actual transformer architecture? If yes → strong P4 in modern setting.
+
+Block-level DFA: each TransformerBlock is a "layer". The DFA credit
+`a_l = e_T @ B_l^T` is broadcast across all tokens at that block's input. The
+local block loss is `<block_l(h_l), broadcast(a_l)>` summed over tokens.
+
+Usage:
+ CUDA_VISIBLE_DEVICES=2 nohup python experiments/snapshot_evolution_vit.py \
+ --output_dir results/snapshot_vit_v1 --epochs 60 --seed 42 \
+ > results/snapshot_vit_v1/run_s42.log 2>&1 &
+"""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.vit_mini import ViTMini, TransformerBlock
+from metrics.credit_metrics import cosine_similarity_batch
+
+
+def get_cifar10(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2))
+
+
+def fixed_eval_buffer(test_loader, device, n_samples=1024):
+ xs, ys = [], []
+ for x, y in test_loader:
+ xs.append(x); ys.append(y)
+ if sum(xb.size(0) for xb in xs) >= n_samples:
+ break
+ return torch.cat(xs)[:n_samples].to(device), torch.cat(ys)[:n_samples].to(device)
+
+
+def diagnose(model, x_eval, y_eval, dfa_Bs=None):
+ """Compute per-block ||h_l|| and ||BP grad at h_l||, plus optional Γ vs DFA credit."""
+ was_training = model.training
+ model.eval()
+ L = model.num_blocks
+
+ # Hidden states (no grad)
+ with torch.no_grad():
+ _, hiddens = model(x_eval, return_hidden=True)
+ # hiddens[l] is shape (B, n_tokens, d_model)
+ # Reduce to per-sample by taking the cls-token norm OR by flattening across tokens
+ # We'll report cls-token norm (the one that actually flows to the head)
+ hidden_norms_cls = [h[:, 0].norm(dim=-1).median().item() for h in hiddens]
+ hidden_norms_avg = [h.norm(dim=-1).mean().item() for h in hiddens] # avg across tokens then over batch
+
+ # BP gradients
+ h0 = model.embed(x_eval.detach())
+ hs = [h0.clone().requires_grad_(True)]
+ for b in model.blocks:
+ hs.append(b(hs[-1]))
+ h_cls = model.out_ln(hs[-1][:, 0])
+ logits = model.out_head(h_cls)
+ loss = F.cross_entropy(logits, y_eval)
+ grads = torch.autograd.grad(loss, hs)
+ # grads[l] is shape (B, n_tokens, d_model)
+ # Per-sample L2 norm: take Frobenius over tokens × d_model
+ bp_grad_per_sample_l2 = [g.flatten(1).norm(dim=-1).median().item() for g in grads]
+ bp_grad_F = [g.norm().item() for g in grads]
+ bp_full = [g.detach() for g in grads]
+
+ acc = (logits.argmax(-1) == y_eval).float().mean().item()
+ loss_val = loss.item()
+
+ gamma_dfa = float('nan'); per_layer_gamma = []
+ if dfa_Bs is not None:
+ with torch.no_grad():
+ e_T = logits.softmax(-1); e_T[torch.arange(x_eval.size(0)), y_eval] -= 1
+ for l in range(L):
+ # Block-level DFA credit: per-sample (B, d_model), broadcast to (B, n_tokens, d_model)
+ a_dfa_per_sample = (e_T @ dfa_Bs[l].T).detach() # (B, d_model)
+ a_dfa_broadcast = a_dfa_per_sample.unsqueeze(1).expand_as(bp_full[l]) # (B, n_tokens, d_model)
+ # Cosine using flattened (per-sample) representation
+ per_layer_gamma.append(cosine_similarity_batch(
+ a_dfa_broadcast.flatten(1), bp_full[l].flatten(1)))
+ gamma_dfa = float(np.mean(per_layer_gamma))
+
+ if was_training:
+ model.train()
+
+ return {
+ 'hidden_norms_cls': hidden_norms_cls,
+ 'hidden_norms_avg': hidden_norms_avg,
+ 'bp_grad_per_sample_l2_med': bp_grad_per_sample_l2,
+ 'bp_grad_F': bp_grad_F,
+ 'gamma_dfa': gamma_dfa,
+ 'gamma_dfa_per_layer': per_layer_gamma,
+ 'acc_eval': acc,
+ 'loss_eval': loss_val,
+ }
+
+
+def train_bp(model, train_loader, x_eval, y_eval, device, epochs, lr, wd):
+ opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ log = []
+ d0 = diagnose(model, x_eval, y_eval); d0['epoch'] = 0; log.append(d0)
+ print(f" [BP-vit] Ep 0: ||h_L_cls||={d0['hidden_norms_cls'][-1]:.3e} ||g_2||={d0['bp_grad_per_sample_l2_med'][2]:.3e} acc={d0['acc_eval']:.4f}", flush=True)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.to(device); y = y.to(device)
+ logits = model(x); loss = F.cross_entropy(logits, y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ d = diagnose(model, x_eval, y_eval); d['epoch'] = ep; log.append(d)
+ if ep % 5 == 0 or ep == 1 or ep == epochs:
+ print(f" [BP-vit] Ep {ep}: ||h_L_cls||={d['hidden_norms_cls'][-1]:.3e} ||g_2||={d['bp_grad_per_sample_l2_med'][2]:.3e} acc={d['acc_eval']:.4f}", flush=True)
+ return log
+
+
+def train_dfa_block_level(model, train_loader, x_eval, y_eval, device, epochs, lr, wd):
+ """Block-level DFA on ViT. Each TransformerBlock is treated as a unit; DFA credit
+ is broadcast across all tokens at the block's input.
+ """
+ d_model = model.d_hidden
+ L = model.num_blocks
+ C = 10
+ Bs = [torch.randn(d_model, C, device=device) / np.sqrt(C) for _ in range(L)]
+
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ embed_opt = optim.AdamW(
+ list(model.patch_embed.parameters()) + [model.cls_token, model.pos_embed],
+ lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=lr, weight_decay=wd)
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+ log = []
+ d0 = diagnose(model, x_eval, y_eval, dfa_Bs=Bs); d0['epoch'] = 0; log.append(d0)
+ print(f" [DFA-vit] Ep 0: ||h_L_cls||={d0['hidden_norms_cls'][-1]:.3e} ||g_2||={d0['bp_grad_per_sample_l2_med'][2]:.3e} acc={d0['acc_eval']:.4f}", flush=True)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.to(device); y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1
+ hL_det = hiddens[-1].detach()
+ # Head update via direct CE on cls token
+ h_cls = model.out_ln(hL_det[:, 0])
+ logits_out = model.out_head(h_cls)
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad(); loss_out.backward(); head_opt.step()
+ # Block updates: each block's local loss = <block(h_l), a_dfa_broadcast>
+ for l in range(L):
+ h_l = hiddens[l].detach() # (B, n_tokens, d)
+ a_dfa = (e_T @ Bs[l].T).detach() # (B, d)
+ a_dfa_broadcast = a_dfa.unsqueeze(1).expand_as(h_l) # (B, n_tokens, d)
+ rms = (a_dfa_broadcast ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a_dfa_broadcast / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad(); local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+ # Embed update (patch embed + cls + pos)
+ a_0 = (e_T @ Bs[0].T).detach()
+ rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x) # (B, n_tokens, d)
+ a_0_broadcast = a_0.unsqueeze(1).expand_as(h0)
+ embed_loss = (h0 * (a_0_broadcast / rms_0.unsqueeze(1))).sum(dim=-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+ for s in all_sch: s.step()
+ d = diagnose(model, x_eval, y_eval, dfa_Bs=Bs); d['epoch'] = ep; log.append(d)
+ if ep % 5 == 0 or ep == 1 or ep == epochs:
+ print(f" [DFA-vit] Ep {ep}: ||h_L_cls||={d['hidden_norms_cls'][-1]:.3e} ||g_2||={d['bp_grad_per_sample_l2_med'][2]:.3e} acc={d['acc_eval']:.4f} γ={d['gamma_dfa']:.4f}", flush=True)
+ return log
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--output_dir', type=str, default='results/snapshot_vit_v1')
+ p.add_argument('--epochs', type=int, default=60)
+ p.add_argument('--lr', type=float, default=1e-3)
+ p.add_argument('--wd', type=float, default=0.05)
+ p.add_argument('--seed', type=int, default=42)
+ p.add_argument('--depth', type=int, default=4)
+ p.add_argument('--d_model', type=int, default=128)
+ p.add_argument('--n_heads', type=int, default=4)
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ device = torch.device('cuda:0')
+ print(f"ViT-MINI: depth={args.depth}, d_model={args.d_model}, n_heads={args.n_heads}, "
+ f"epochs={args.epochs}, seed={args.seed}", flush=True)
+
+ train_loader, test_loader = get_cifar10(batch_size=128)
+ x_eval, y_eval = fixed_eval_buffer(test_loader, device, n_samples=1024)
+
+ print("\n=== BP training (ViT-Mini) ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ bp_model = ViTMini(num_blocks=args.depth, d_model=args.d_model, n_heads=args.n_heads).to(device)
+ print(f" n_params={sum(p.numel() for p in bp_model.parameters())}", flush=True)
+ bp_log = train_bp(bp_model, train_loader, x_eval, y_eval, device, args.epochs, args.lr, args.wd)
+
+ print("\n=== DFA training (ViT-Mini, block-level DFA) ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ dfa_model = ViTMini(num_blocks=args.depth, d_model=args.d_model, n_heads=args.n_heads).to(device)
+ dfa_log = train_dfa_block_level(dfa_model, train_loader, x_eval, y_eval, device, args.epochs, args.lr, args.wd)
+
+ out = {
+ 'config': vars(args), 'depth': args.depth, 'd_model': args.d_model,
+ 'architecture': 'ViTMini', 'bp_log': bp_log, 'dfa_log': dfa_log,
+ }
+ out_path = os.path.join(args.output_dir, f'snapshot_vit_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(out, f, indent=2)
+ print(f"\nSaved {out_path}", flush=True)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/snapshot_fa_crossarch.py b/experiments/snapshot_fa_crossarch.py
new file mode 100644
index 0000000..8fa9e71
--- /dev/null
+++ b/experiments/snapshot_fa_crossarch.py
@@ -0,0 +1,243 @@
+"""
+FA-only snapshot evolution for ViT-Mini and ResMLP-no-outLN.
+Produces per-epoch ||h_L||, ||g_L||, acc for FA training.
+"""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision, torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+from models.vit_mini import ViTMini
+
+
+def get_cifar10(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2))
+
+
+def fixed_eval_buffer(loader, device, n=1024):
+ xs, ys = [], []
+ for x, y in loader:
+ xs.append(x); ys.append(y)
+ if sum(xb.size(0) for xb in xs) >= n:
+ break
+ return torch.cat(xs)[:n].to(device), torch.cat(ys)[:n].to(device)
+
+
+# ─── Diagnose (works for both ViT and ResMLP) ───────────────────────────
+
+def diagnose_resmlp(model, x_eval, y_eval):
+ model.eval()
+ x_flat = x_eval.view(x_eval.size(0), -1)
+ with torch.no_grad():
+ _, hiddens = model(x_flat, return_hidden=True)
+ hidden_norms = [h.norm(dim=-1).median().item() for h in hiddens]
+ # BP grads
+ h0 = model.embed(x_flat.detach())
+ hs = [h0.clone().requires_grad_(True)]
+ for b in model.blocks:
+ hs.append(hs[-1] + b(hs[-1]))
+ # Handle both with and without out_ln
+ if hasattr(model, 'out_ln'):
+ logits = model.out_head(model.out_ln(hs[-1]))
+ else:
+ logits = model.out_head(hs[-1])
+ loss = F.cross_entropy(logits, y_eval)
+ grads = torch.autograd.grad(loss, hs)
+ g_norms = [g.norm(dim=-1).median().item() for g in grads]
+ acc = (logits.argmax(-1) == y_eval).float().mean().item()
+ model.train()
+ return {'hidden_norms': hidden_norms, 'bp_grad_norms_per_sample_med': g_norms, 'acc_eval': acc}
+
+
+def diagnose_vit(model, x_eval, y_eval):
+ model.eval()
+ with torch.no_grad():
+ _, hiddens = model(x_eval, return_hidden=True)
+ h_cls_norms = [h[:, 0].norm(dim=-1).median().item() for h in hiddens]
+ # BP grads via manual forward
+ h0 = model.embed(x_eval.detach())
+ hs = [h0.clone().requires_grad_(True)]
+ for b in model.blocks:
+ hs.append(hs[-1] + b(hs[-1]))
+ h_cls = model.out_ln(hs[-1][:, 0])
+ logits = model.out_head(h_cls)
+ loss = F.cross_entropy(logits, y_eval)
+ grads = torch.autograd.grad(loss, hs)
+ g_cls_norms = [g[:, 0].norm(dim=-1).median().item() for g in grads]
+ acc = (logits.argmax(-1) == y_eval).float().mean().item()
+ model.train()
+ return {'hidden_norms_cls': h_cls_norms, 'bp_grad_per_sample_l2_med': g_cls_norms, 'acc_eval': acc}
+
+
+# ─── FA training ─────────────────────────────────────────────────────────
+
+def train_fa_resmlp(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, no_outln=False):
+ d_hidden = model.d_hidden
+ L = model.num_blocks
+ Bs = [torch.randn(d_hidden, d_hidden, device=device) / np.sqrt(d_hidden) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_params = list(model.out_head.parameters())
+ if hasattr(model, 'out_ln') and model.out_ln is not None:
+ head_params += list(model.out_ln.parameters())
+ head_opt = optim.AdamW(head_params, lr=lr, weight_decay=wd)
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+ log = []
+ d0 = diagnose_resmlp(model, x_eval, y_eval); d0['epoch'] = 0; log.append(d0)
+ print(f" [FA] Ep 0: acc={d0['acc_eval']:.4f}", flush=True)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ hL_det = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL_det)) if hasattr(model, 'out_ln') else model.out_head(hL_det)
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad(); loss_out.backward(); head_opt.step()
+ # FA credits
+ hL_req = hiddens[-1].detach().requires_grad_(True)
+ logits_fa = model.out_head(model.out_ln(hL_req)) if hasattr(model, 'out_ln') else model.out_head(hL_req)
+ loss_fa = F.cross_entropy(logits_fa, y, reduction='sum')
+ a_L = torch.autograd.grad(loss_fa, hL_req)[0].detach()
+ credits = [None] * L
+ credits[L-1] = a_L
+ for ll in range(L-2, -1, -1):
+ credits[ll] = (credits[ll+1] @ Bs[ll+1]).detach()
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a_l = credits[l]
+ rms = (a_l**2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * (a_l / rms)).sum(dim=-1).mean()
+ block_opts[l].zero_grad(); local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+ a_0 = credits[0]
+ rms_0 = (a_0**2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+ for s in all_sch: s.step()
+ d = diagnose_resmlp(model, x_eval, y_eval); d['epoch'] = ep; log.append(d)
+ if ep % 10 == 0 or ep == 1 or ep == epochs:
+ print(f" [FA] Ep {ep}: ||h_L||={d['hidden_norms'][-1]:.3e} "
+ f"||g_L||={d['bp_grad_norms_per_sample_med'][-1]:.3e} "
+ f"acc={d['acc_eval']:.4f}", flush=True)
+ return log
+
+
+def train_fa_vit(model, train_loader, x_eval, y_eval, device, epochs, lr, wd):
+ """Canonical FA for ViT: mean reduction, grad before step, no clipping, top-down."""
+ d_model = model.d_hidden
+ L = model.num_blocks
+ Bs = [torch.randn(d_model, d_model, device=device) / np.sqrt(d_model) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ embed_opt = optim.AdamW(
+ list(model.patch_embed.parameters()) + [model.cls_token, model.pos_embed],
+ lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=lr, weight_decay=wd)
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+ log = []
+ d0 = diagnose_vit(model, x_eval, y_eval); d0['epoch'] = 0; log.append(d0)
+ print(f" [FA-vit] Ep 0: acc={d0['acc_eval']:.4f}", flush=True)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.to(device); y = y.to(device)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ # Head update — grad BEFORE step (old head)
+ hL_det = hiddens[-1].detach().requires_grad_(True)
+ h_cls = model.out_ln(hL_det[:, 0])
+ logits_out = model.out_head(h_cls)
+ loss_out = F.cross_entropy(logits_out, y) # mean reduction
+ head_opt.zero_grad()
+ loss_out.backward()
+ a_L_full = hL_det.grad.detach() # (B, n_tokens, d)
+ head_opt.step()
+ # Use mean over tokens for the backward signal
+ a_credit = a_L_full.mean(dim=1) # (B, d)
+ # Top-down block updates, propagate credit after each
+ for l in range(L - 1, -1, -1):
+ h_l = hiddens[l].detach()
+ a_broadcast = a_credit.unsqueeze(1).expand_as(h_l)
+ rms = (a_broadcast ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * (a_broadcast / rms)).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ block_opts[l].step() # no clipping
+ a_credit = (a_credit @ Bs[l]).detach()
+ # Embed update with final propagated credit
+ a_0_broadcast = a_credit.unsqueeze(1)
+ rms_0 = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ embed_loss = (h0 * (a_0_broadcast / rms_0.unsqueeze(1))).sum(dim=-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+ for s in all_sch: s.step()
+ d = diagnose_vit(model, x_eval, y_eval); d['epoch'] = ep; log.append(d)
+ if ep % 5 == 0 or ep == 1 or ep == epochs:
+ print(f" [FA-vit] Ep {ep}: ||h_L||={d['hidden_norms_cls'][-1]:.3e} "
+ f"||g_L||={d['bp_grad_per_sample_l2_med'][-1]:.3e} "
+ f"acc={d['acc_eval']:.4f}", flush=True)
+ return log
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--arch', choices=['vit', 'resmlp_noln'], required=True)
+ p.add_argument('--output', type=str, required=True)
+ p.add_argument('--epochs', type=int, default=100)
+ p.add_argument('--seed', type=int, default=42)
+ args = p.parse_args()
+
+ device = torch.device('cuda:0')
+ train_loader, test_loader = get_cifar10(128)
+ x_eval, y_eval = fixed_eval_buffer(test_loader, device, 1024)
+
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+
+ if args.arch == 'vit':
+ # Match ViT snapshot params
+ model = ViTMini(d_model=128, n_heads=4, num_blocks=4, num_classes=10).to(device)
+ fa_log = train_fa_vit(model, train_loader, x_eval, y_eval, device,
+ args.epochs, lr=1e-3, wd=0.05)
+ else:
+ # ResMLP without terminal LN — use the same class as the original no-outln experiment
+ from experiments.snapshot_evolution_no_outln import ResidualMLP_NoOutLN
+ model = ResidualMLP_NoOutLN(3072, 256, 10, 4).to(device)
+ fa_log = train_fa_resmlp(model, train_loader, x_eval, y_eval, device,
+ args.epochs, lr=1e-3, wd=0.01, no_outln=True)
+
+ with open(args.output, 'w') as f:
+ json.dump({'fa_log': fa_log, 'arch': args.arch, 'seed': args.seed}, f, indent=2)
+ print(f"Saved: {args.output}", flush=True)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/snapshot_fa_only.py b/experiments/snapshot_fa_only.py
new file mode 100644
index 0000000..cdc69ae
--- /dev/null
+++ b/experiments/snapshot_fa_only.py
@@ -0,0 +1,38 @@
+"""Quick FA-only snapshot evolution. Reuses the full script's train_fa + diagnose."""
+import os, sys, json, argparse
+import numpy as np
+import torch
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from experiments.snapshot_evolution_residual_explosion import (
+ get_cifar10, fixed_eval_buffer, train_fa
+)
+from models.residual_mlp import ResidualMLP
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--output', type=str, required=True)
+ p.add_argument('--epochs', type=int, default=100)
+ p.add_argument('--seed', type=int, default=42)
+ p.add_argument('--depth', type=int, default=4)
+ p.add_argument('--d_hidden', type=int, default=256)
+ args = p.parse_args()
+
+ device = torch.device('cuda:0')
+ train_loader, test_loader = get_cifar10(batch_size=128)
+ x_eval, y_eval = fixed_eval_buffer(test_loader, device, n_samples=1024)
+
+ L, d, C = args.depth, args.d_hidden, 10
+ print(f"FA snapshot: depth={L}, d={d}, seed={args.seed}, epochs={args.epochs}", flush=True)
+
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ model = ResidualMLP(3072, d, C, L).to(device)
+ fa_log = train_fa(model, train_loader, x_eval, y_eval, device,
+ args.epochs, 1e-3, 0.01, log_every=1)
+
+ with open(args.output, 'w') as f:
+ json.dump({'fa_log': fa_log, 'seed': args.seed, 'depth': L, 'd_hidden': d}, f, indent=2)
+ print(f"Saved: {args.output}", flush=True)
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/snapshot_fa_studentnet.py b/experiments/snapshot_fa_studentnet.py
new file mode 100644
index 0000000..887365c
--- /dev/null
+++ b/experiments/snapshot_fa_studentnet.py
@@ -0,0 +1,94 @@
+"""FA-only snapshot evolution for StudentNet (synthetic teacher-student)."""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader, TensorDataset
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from experiments.confirmatory_paper_experiments import (
+ StudentNet, TeacherNet, generate_synth_dataset, set_seed
+)
+from experiments.snapshot_synth_residual_explosion import diagnose_synth
+
+
+def train_fa_synth(model, train_loader, x_eval, y_eval, device, epochs, lr, wd):
+ """Canonical FA for StudentNet: mean reduction, grad before step, no clipping."""
+ d_hidden = model.d_hidden
+ L = model.num_blocks
+ Bs = [torch.randn(d_hidden, d_hidden, device=device) / np.sqrt(d_hidden) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+ log = []
+ d0 = diagnose_synth(model, x_eval, y_eval); d0['epoch'] = 0; log.append(d0)
+ print(f" [FA] Ep 0: acc={d0['acc_eval']:.4f}", flush=True)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.to(device); y = y.to(device)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ # Head update — grad BEFORE step (old head)
+ hL_det = hiddens[-1].detach().requires_grad_(True)
+ logits_out = model.out_head(hL_det)
+ loss_out = F.cross_entropy(logits_out, y) # mean reduction
+ head_opt.zero_grad()
+ loss_out.backward()
+ a_credit = hL_det.grad.detach()
+ head_opt.step()
+ # Top-down block updates, propagate credit after each
+ for l in range(L - 1, -1, -1):
+ h_l = hiddens[l].detach()
+ rms = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * (a_credit / rms)).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ block_opts[l].step() # no clipping
+ a_credit = (a_credit @ Bs[l]).detach()
+ # No embed for StudentNet (input is already d_hidden)
+ for s in all_sch: s.step()
+ d = diagnose_synth(model, x_eval, y_eval); d['epoch'] = ep; log.append(d)
+ if ep % 5 == 0 or ep in (1, epochs):
+ print(f" [FA] Ep {ep}: ||h_L||={d['hidden_norms'][-1]:.3e} "
+ f"||g||={d['bp_grad_per_sample_l2_med'][2]:.3e} "
+ f"acc={d['acc_eval']:.4f}", flush=True)
+ return log
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--output', type=str, required=True)
+ p.add_argument('--epochs', type=int, default=80)
+ p.add_argument('--seed', type=int, default=42)
+ p.add_argument('--alpha', type=float, default=1.0)
+ p.add_argument('--depth', type=int, default=4)
+ p.add_argument('--d_hidden', type=int, default=128)
+ args = p.parse_args()
+
+ device = torch.device('cuda:0')
+ L, d, C = args.depth, args.d_hidden, 10
+ set_seed(args.seed)
+ teacher = TeacherNet(d, L, C, args.alpha, seed=0).to(device)
+ X_tr, Y_tr = generate_synth_dataset(teacher, 50*256, d, device, seed=args.seed)
+ X_te, Y_te = generate_synth_dataset(teacher, 2000, d, device, seed=args.seed+10000)
+ train_loader = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=256, shuffle=True)
+
+ print(f"StudentNet FA: alpha={args.alpha}, L={L}, d={d}, seed={args.seed}", flush=True)
+ set_seed(args.seed)
+ model = StudentNet(d, C, L, args.alpha).to(device)
+ fa_log = train_fa_synth(model, train_loader, X_te.to(device), Y_te.to(device),
+ device, args.epochs, 1e-3, 0.01)
+
+ with open(args.output, 'w') as f:
+ json.dump({'fa_log': fa_log, 'seed': args.seed, 'alpha': args.alpha,
+ 'depth': L, 'd_hidden': d}, f, indent=2)
+ print(f"Saved: {args.output}", flush=True)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/snapshot_synth_residual_explosion.py b/experiments/snapshot_synth_residual_explosion.py
new file mode 100644
index 0000000..3470667
--- /dev/null
+++ b/experiments/snapshot_synth_residual_explosion.py
@@ -0,0 +1,195 @@
+"""
+Synthetic snapshot evolution: per-epoch logging of ||h_l||_2 and ||BP grad||_2
+on a teacher-student StudentNet (NO out_ln) trained with BP vs DFA.
+
+Goal: test whether the residual-stream explosion observed in CIFAR ResidualMLP
+(pre-LN with out_ln before head) also happens in the synthetic StudentNet
+architecture (no out_ln; head reads h_L directly). If synthetic does NOT show
+the explosion, then out_ln is causally responsible for the CIFAR pathology and
+the paper's P4 claim narrows to "pre-LN architectures with terminal LN".
+
+Usage:
+ CUDA_VISIBLE_DEVICES=2 nohup python experiments/snapshot_synth_residual_explosion.py \
+ --output_dir results/snapshot_synth_v1 --epochs 80 --alpha 1.0 --depth 4 --seed 42 \
+ > results/snapshot_synth_v1/run_a1.0_s42.log 2>&1 &
+"""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader, TensorDataset
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from metrics.credit_metrics import cosine_similarity_batch
+# Import the StudentNet/TeacherNet/generate_synth_dataset directly from confirmatory script
+from experiments.confirmatory_paper_experiments import (
+ StudentNet, TeacherNet, generate_synth_dataset, set_seed
+)
+
+
+def diagnose_synth(model, x_eval, y_eval, dfa_Bs=None):
+ was_training = model.training
+ model.eval()
+ L = model.num_blocks
+
+ with torch.no_grad():
+ _, hi = model(x_eval, return_hidden=True)
+ hidden_norms = [h.norm(dim=-1).median().item() for h in hi]
+
+ # BP grads
+ h_list = [x_eval.detach().requires_grad_(True)]
+ for block in model.blocks:
+ h_list.append(h_list[-1] + block(h_list[-1]))
+ logits = model.out_head(h_list[-1])
+ loss = F.cross_entropy(logits, y_eval)
+ grads = torch.autograd.grad(loss, h_list)
+ bp_grad_l2 = [g.norm(dim=-1).median().item() for g in grads]
+ bp_grad_F = [g.norm().item() for g in grads]
+ bp_full = [g.detach() for g in grads]
+ acc = (logits.argmax(-1) == y_eval).float().mean().item()
+ loss_val = loss.item()
+
+ gamma_dfa = float('nan')
+ per_layer_gamma = []
+ if dfa_Bs is not None:
+ with torch.no_grad():
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(x_eval.size(0)), y_eval] -= 1.0
+ for l in range(L):
+ a_dfa = (e_T @ dfa_Bs[l].T).detach()
+ per_layer_gamma.append(cosine_similarity_batch(a_dfa, bp_full[l]))
+ gamma_dfa = float(np.mean(per_layer_gamma))
+
+ if was_training:
+ model.train()
+ return {
+ 'hidden_norms': hidden_norms,
+ 'bp_grad_per_sample_l2_med': bp_grad_l2,
+ 'bp_grad_F': bp_grad_F,
+ 'gamma_dfa': gamma_dfa,
+ 'gamma_dfa_per_layer': per_layer_gamma,
+ 'acc_eval': acc,
+ 'loss_eval': loss_val,
+ }
+
+
+def train_bp(model, train_loader, x_eval, y_eval, device, epochs, lr, wd):
+ opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ log = []
+ d0 = diagnose_synth(model, x_eval, y_eval); d0['epoch'] = 0; log.append(d0)
+ print(f" [BP] Ep 0: ||h_L||={d0['hidden_norms'][-1]:.3e} ||g||={d0['bp_grad_per_sample_l2_med'][2]:.3e} acc={d0['acc_eval']:.4f}", flush=True)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.to(device); y = y.to(device)
+ logits = model(x)
+ loss = F.cross_entropy(logits, y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ d = diagnose_synth(model, x_eval, y_eval); d['epoch'] = ep; log.append(d)
+ if ep % 5 == 0 or ep in (1, epochs):
+ print(f" [BP] Ep {ep}: ||h_L||={d['hidden_norms'][-1]:.3e} ||g||={d['bp_grad_per_sample_l2_med'][2]:.3e} acc={d['acc_eval']:.4f}", flush=True)
+ return log
+
+
+def train_dfa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd):
+ d_hidden = model.d_hidden
+ L = model.num_blocks
+ C = 10
+ Bs = [torch.randn(d_hidden, C, device=device) / np.sqrt(C) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+ log = []
+ d0 = diagnose_synth(model, x_eval, y_eval, dfa_Bs=Bs); d0['epoch'] = 0; log.append(d0)
+ print(f" [DFA] Ep 0: ||h_L||={d0['hidden_norms'][-1]:.3e} ||g||={d0['bp_grad_per_sample_l2_med'][2]:.3e} acc={d0['acc_eval']:.4f}", flush=True)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.to(device); y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ hL_det = hiddens[-1].detach()
+ # head update via direct CE on head(hL)
+ logits_out = model.out_head(hL_det)
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad(); loss_out.backward(); head_opt.step()
+ # block updates via DFA local credit
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a_dfa = (e_T @ Bs[l].T).detach()
+ rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a_dfa / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad(); local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+ for s in all_sch:
+ s.step()
+ d = diagnose_synth(model, x_eval, y_eval, dfa_Bs=Bs); d['epoch'] = ep; log.append(d)
+ if ep % 5 == 0 or ep in (1, epochs):
+ print(f" [DFA] Ep {ep}: ||h_L||={d['hidden_norms'][-1]:.3e} ||g||={d['bp_grad_per_sample_l2_med'][2]:.3e} acc={d['acc_eval']:.4f} γ_dfa={d['gamma_dfa']:.4f}", flush=True)
+ return log
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--output_dir', type=str, default='results/snapshot_synth_v1')
+ p.add_argument('--epochs', type=int, default=80)
+ p.add_argument('--alpha', type=float, default=1.0)
+ p.add_argument('--depth', type=int, default=4)
+ p.add_argument('--seed', type=int, default=42)
+ p.add_argument('--d_hidden', type=int, default=128)
+ p.add_argument('--lr', type=float, default=1e-3)
+ p.add_argument('--wd', type=float, default=0.01)
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ device = torch.device('cuda:0')
+ print(f"device={device}, alpha={args.alpha}, depth={args.depth}, "
+ f"d_hidden={args.d_hidden}, epochs={args.epochs}, seed={args.seed}", flush=True)
+
+ set_seed(args.seed)
+ L, d, C = args.depth, args.d_hidden, 10
+ teacher = TeacherNet(d, L, C, args.alpha, seed=0).to(device)
+
+ n_train = 50 * 256
+ n_test = 2000
+ X_tr, Y_tr = generate_synth_dataset(teacher, n_train, d, device, seed=args.seed)
+ X_te, Y_te = generate_synth_dataset(teacher, n_test, d, device, seed=args.seed + 10000)
+ train_loader = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=256, shuffle=True)
+ x_eval, y_eval = X_te.to(device), Y_te.to(device)
+ print(f"train: {X_tr.shape}, test eval buffer: {x_eval.shape}", flush=True)
+
+ print("\n=== BP training ===", flush=True)
+ set_seed(args.seed)
+ bp_model = StudentNet(d, C, L, args.alpha).to(device)
+ bp_log = train_bp(bp_model, train_loader, x_eval, y_eval, device, args.epochs, args.lr, args.wd)
+
+ print("\n=== DFA training ===", flush=True)
+ set_seed(args.seed)
+ dfa_model = StudentNet(d, C, L, args.alpha).to(device)
+ dfa_log = train_dfa(dfa_model, train_loader, x_eval, y_eval, device, args.epochs, args.lr, args.wd)
+
+ out = {
+ 'config': vars(args),
+ 'depth': L, 'd_hidden': d, 'num_classes': C,
+ 'bp_log': bp_log,
+ 'dfa_log': dfa_log,
+ }
+ out_path = os.path.join(args.output_dir, f'snapshot_synth_a{args.alpha}_L{L}_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(out, f, indent=2)
+ print(f"\nSaved {out_path}", flush=True)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/vit_frozen_blocks_baseline.py b/experiments/vit_frozen_blocks_baseline.py
new file mode 100644
index 0000000..8b53198
--- /dev/null
+++ b/experiments/vit_frozen_blocks_baseline.py
@@ -0,0 +1,177 @@
+"""
+Frozen-random-blocks baseline for ViT-Mini: train BP and DFA where the 4
+transformer blocks are randomly initialized and FROZEN (no parameter updates).
+Only patch_embed + cls_token + pos_embed + out_ln + out_head are trainable.
+
+This is the codex-round-6 control for the "DFA actually trains the transformer
+blocks" claim. If frozen-blocks DFA gets ≈ 24% (matching the trainable-blocks
+4-block ViT-Mini DFA acc), then the blocks are passengers — DFA's "24%" is
+coming from patch_embed + head learning routed via untrained block mixing.
+If frozen-blocks DFA stays much lower than 24%, then the trainable blocks
+are doing learned work.
+
+Usage:
+ CUDA_VISIBLE_DEVICES=2 python experiments/vit_frozen_blocks_baseline.py
+"""
+import sys, os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import numpy as np
+
+from models.vit_mini import ViTMini
+
+
+def get_loaders(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (
+ DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2),
+ )
+
+
+def evaluate(model, loader, dev):
+ model.eval()
+ n = c = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x, y = x.to(dev), y.to(dev)
+ preds = model(x).argmax(-1)
+ c += (preds == y).sum().item()
+ n += x.size(0)
+ return c / n
+
+
+def freeze_blocks(model):
+ for p in model.blocks.parameters():
+ p.requires_grad_(False)
+ model.blocks.eval()
+
+
+def train_bp_frozen(train_loader, test_loader, dev, epochs=30, seed=42, lr=1e-3, wd=0.05):
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ m = ViTMini(num_blocks=4, d_model=128, n_heads=4).to(dev)
+ freeze_blocks(m)
+ n_trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
+ n_total = sum(p.numel() for p in m.parameters())
+ print(f"BP-frozen-blocks: {n_trainable}/{n_total} params trainable", flush=True)
+ opt = optim.AdamW(filter(lambda p: p.requires_grad, m.parameters()), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ m.train()
+ m.blocks.eval() # keep blocks in eval mode (no dropout etc)
+ for x, y in train_loader:
+ x = x.to(dev); y = y.to(dev)
+ loss = F.cross_entropy(m(x), y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ if ep % 5 == 0 or ep == 1 or ep == epochs:
+ acc = evaluate(m, test_loader, dev)
+ print(f" BP-frozen ep {ep}: test_acc={acc:.4f}", flush=True)
+ return m
+
+
+def train_dfa_frozen(train_loader, test_loader, dev, epochs=30, seed=42, lr=1e-3, wd=0.05):
+ """4 transformer blocks frozen at random init.
+ Trainable: patch_embed, cls_token, pos_embed, out_ln, out_head.
+ DFA-style: head with true CE on cls token; embed (patch+cls+pos) with random feedback."""
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ m = ViTMini(num_blocks=4, d_model=128, n_heads=4).to(dev)
+ freeze_blocks(m)
+ n_trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
+ n_total = sum(p.numel() for p in m.parameters())
+ print(f"DFA-frozen-blocks: {n_trainable}/{n_total} params trainable", flush=True)
+
+ d_model, C = 128, 10
+ B0 = torch.randn(d_model, C, device=dev) / np.sqrt(C)
+ embed_opt = optim.AdamW(
+ list(m.patch_embed.parameters()) + [m.cls_token, m.pos_embed],
+ lr=lr, weight_decay=wd
+ )
+ head_opt = optim.AdamW(
+ list(m.out_head.parameters()) + list(m.out_ln.parameters()),
+ lr=lr, weight_decay=wd
+ )
+ sch1 = optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs)
+ sch2 = optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ m.train()
+ m.blocks.eval()
+ for x, y in train_loader:
+ x = x.to(dev); y = y.to(dev)
+ with torch.no_grad():
+ logits, hi = m(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(x.size(0)), y] -= 1
+ hL_det = hi[-1].detach()
+ # Head update via true CE on cls token
+ h_cls = m.out_ln(hL_det[:, 0])
+ head_opt.zero_grad()
+ F.cross_entropy(m.out_head(h_cls), y).backward()
+ head_opt.step()
+ # Embed update via DFA feedback
+ a0 = (e_T @ B0.T).detach()
+ rms = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ h0 = m.embed(x)
+ a0_b = a0.unsqueeze(1).expand_as(h0)
+ embed_loss = (h0 * (a0_b / rms.unsqueeze(1))).sum(-1).mean()
+ embed_opt.zero_grad()
+ embed_loss.backward()
+ embed_opt.step()
+ sch1.step(); sch2.step()
+ if ep % 5 == 0 or ep == 1 or ep == epochs:
+ acc = evaluate(m, test_loader, dev)
+ print(f" DFA-frozen ep {ep}: test_acc={acc:.4f}", flush=True)
+ return m
+
+
+def main():
+ import argparse
+ p = argparse.ArgumentParser()
+ p.add_argument('--seed', type=int, default=42)
+ p.add_argument('--epochs', type=int, default=30)
+ args = p.parse_args()
+
+ dev = torch.device('cuda:0')
+ print(f"Device: {dev}, seed={args.seed}, epochs={args.epochs}", flush=True)
+ train_loader, test_loader = get_loaders(batch_size=128)
+
+ print(f"\n=== BP frozen-blocks baseline (4 random-init transformer blocks, frozen), seed={args.seed} ===", flush=True)
+ mb = train_bp_frozen(train_loader, test_loader, dev, epochs=args.epochs, seed=args.seed)
+ bp_acc = evaluate(mb, test_loader, dev)
+ print(f"FINAL BP-frozen-blocks acc: {bp_acc:.4f}", flush=True)
+
+ print(f"\n=== DFA frozen-blocks baseline, seed={args.seed} ===", flush=True)
+ md = train_dfa_frozen(train_loader, test_loader, dev, epochs=args.epochs, seed=args.seed)
+ dfa_acc = evaluate(md, test_loader, dev)
+ print(f"FINAL DFA-frozen-blocks acc: {dfa_acc:.4f}", flush=True)
+
+ print(f"\n=== Summary ===")
+ print(f"BP-frozen-blocks: {bp_acc:.4f} (chance=0.10)")
+ print(f"DFA-frozen-blocks: {dfa_acc:.4f}")
+ print(f"Compare to ViT-Mini 4-block trainable (3-seed avg): BP=0.792, DFA=0.237")
+ print(f"Compare to ViT-Mini 0-block (shallow baseline): BP=0.10, DFA=0.10")
+ print()
+ print("Interpretation:")
+ print(" If DFA-frozen-blocks ≈ 0.237: blocks are passengers, DFA is just learning patch_embed+head")
+ print(" If DFA-frozen-blocks << 0.237: trainable blocks ARE doing learned work")
+ print(" If DFA-frozen-blocks ~ 0.10: untrained blocks add no useful mixing (less informative)")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/vit_shallow_baseline.py b/experiments/vit_shallow_baseline.py
new file mode 100644
index 0000000..c030d74
--- /dev/null
+++ b/experiments/vit_shallow_baseline.py
@@ -0,0 +1,147 @@
+"""
+Shallow baseline for ViT-Mini: train BP and DFA on a 0-block ViT (just patch_embed
++ cls + pos + out_ln + out_head), to test whether the DFA accuracy on the full
+ViT is just exploiting the patch embedder + head.
+
+This is the codex-round-5 control for the "DFA actually trains the transformer
+blocks" claim. If shallow DFA acc ≈ 24% (matching the 4-block ViT-Mini DFA acc),
+then the blocks are passengers and the claim is too strong. If shallow DFA acc
+is much lower, then the blocks are doing real work.
+
+Usage:
+ CUDA_VISIBLE_DEVICES=2 python experiments/vit_shallow_baseline.py
+"""
+import sys, os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import numpy as np
+
+from models.vit_mini import ViTMini
+
+
+def get_loaders(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (
+ DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2),
+ )
+
+
+def evaluate(model, loader, dev):
+ model.eval()
+ n = c = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x, y = x.to(dev), y.to(dev)
+ preds = model(x).argmax(-1)
+ c += (preds == y).sum().item()
+ n += x.size(0)
+ return c / n
+
+
+def train_bp_shallow(train_loader, test_loader, dev, epochs=30, seed=42, lr=1e-3, wd=0.05):
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ m = ViTMini(num_blocks=0, d_model=128, n_heads=4).to(dev)
+ print(f"BP-shallow: n_params={sum(p.numel() for p in m.parameters())}", flush=True)
+ opt = optim.AdamW(m.parameters(), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ m.train()
+ for x, y in train_loader:
+ x = x.to(dev); y = y.to(dev)
+ loss = F.cross_entropy(m(x), y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ if ep % 5 == 0 or ep == 1 or ep == epochs:
+ acc = evaluate(m, test_loader, dev)
+ print(f" BP-shallow ep {ep}: test_acc={acc:.4f}", flush=True)
+ return m
+
+
+def train_dfa_shallow(train_loader, test_loader, dev, epochs=30, seed=42, lr=1e-3, wd=0.05):
+ """0-block ViT trained DFA-style: head with true CE on cls token,
+ embed (patch_embed + cls + pos) with random feedback `e_T @ B^T` from the head."""
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ m = ViTMini(num_blocks=0, d_model=128, n_heads=4).to(dev)
+ print(f"DFA-shallow: n_params={sum(p.numel() for p in m.parameters())}", flush=True)
+ d_model, C = 128, 10
+ B0 = torch.randn(d_model, C, device=dev) / np.sqrt(C)
+ embed_opt = optim.AdamW(
+ list(m.patch_embed.parameters()) + [m.cls_token, m.pos_embed],
+ lr=lr, weight_decay=wd
+ )
+ head_opt = optim.AdamW(
+ list(m.out_head.parameters()) + list(m.out_ln.parameters()),
+ lr=lr, weight_decay=wd
+ )
+ sch1 = optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs)
+ sch2 = optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ m.train()
+ for x, y in train_loader:
+ x = x.to(dev); y = y.to(dev)
+ with torch.no_grad():
+ logits, hi = m(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(x.size(0)), y] -= 1
+ hL_det = hi[-1].detach()
+ # Head update via true CE on cls token
+ h_cls = m.out_ln(hL_det[:, 0])
+ head_opt.zero_grad()
+ F.cross_entropy(m.out_head(h_cls), y).backward()
+ head_opt.step()
+ # Embed update via DFA-style local loss
+ a0 = (e_T @ B0.T).detach()
+ rms = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ h0 = m.embed(x) # (B, 65, d_model)
+ a0_b = a0.unsqueeze(1).expand_as(h0)
+ embed_loss = (h0 * (a0_b / rms.unsqueeze(1))).sum(-1).mean()
+ embed_opt.zero_grad()
+ embed_loss.backward()
+ embed_opt.step()
+ sch1.step(); sch2.step()
+ if ep % 5 == 0 or ep == 1 or ep == epochs:
+ acc = evaluate(m, test_loader, dev)
+ print(f" DFA-shallow ep {ep}: test_acc={acc:.4f}", flush=True)
+ return m
+
+
+def main():
+ dev = torch.device('cuda:0')
+ print(f"Device: {dev}", flush=True)
+ train_loader, test_loader = get_loaders(batch_size=128)
+
+ print("\n=== BP shallow baseline (ViT-Mini num_blocks=0) ===", flush=True)
+ mb = train_bp_shallow(train_loader, test_loader, dev, epochs=30, seed=42)
+ bp_acc = evaluate(mb, test_loader, dev)
+ print(f"FINAL BP-shallow acc: {bp_acc:.4f}", flush=True)
+
+ print("\n=== DFA shallow baseline (ViT-Mini num_blocks=0) ===", flush=True)
+ md = train_dfa_shallow(train_loader, test_loader, dev, epochs=30, seed=42)
+ dfa_acc = evaluate(md, test_loader, dev)
+ print(f"FINAL DFA-shallow acc: {dfa_acc:.4f}", flush=True)
+
+ print(f"\n=== Summary ===")
+ print(f"BP-shallow: {bp_acc:.4f} (chance=0.10)")
+ print(f"DFA-shallow: {dfa_acc:.4f}")
+ print(f"Compare to ViT-Mini 4-block (3-seed avg): BP=0.792, DFA=0.237")
+
+
+if __name__ == '__main__':
+ main()