summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 19:24:06 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 19:24:06 -0500
commit2fa24acae8bb7f8c026db2f7fdade4a29b640d8d (patch)
tree98bf266ac07a1d6974769262dff916553223612f /experiments
parentcebc4c4a81809a982a16dd07da41487aa2f30322 (diff)
Sync experiment+protocol scripts with v2.32 corrected control values
The pre-v2.31 unsourced values BP=0.609 and DFA=0.308 (which v2.31 fixed to 0.585 and 0.301 via matched 30-ep controls) were also hardcoded as "compare to" comments in 5 helper scripts: experiments/bp_with_penalty_control.py experiments/dfa_residual_penalty_test.py experiments/resmlp_frozen_blocks_baseline.py protocol/examples/threshold_d_sensitivity.py protocol/examples/plot_penalty_rescue.py These are non-paper-input scripts (their output goes to stdout, not to the paper), so the stale values didn't cause numerical errors in the paper itself. But the original v2.31 BP+pen=0.609 unsourced number bug came from exactly this kind of hardcoded "for-comparison" comment that was never measured. Updating them now to remove the same trap from future runs. Each script now references the matched 30-ep 3-seed values from results/bp_no_penalty_30ep, results/dfa_no_penalty_30ep, results/ dfa_pen_short, and results/bp_with_penalty. protocol/EVIDENCE_SUMMARY.md and PAPER_OUTLINE.md still have stale numbers — these are project scratch documents and not user-facing. Deferred to a separate sweep if needed. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/bp_with_penalty_control.py8
-rw-r--r--experiments/dfa_residual_penalty_test.py10
-rw-r--r--experiments/resmlp_frozen_blocks_baseline.py202
3 files changed, 211 insertions, 9 deletions
diff --git a/experiments/bp_with_penalty_control.py b/experiments/bp_with_penalty_control.py
index b986dee..07ee1f1 100644
--- a/experiments/bp_with_penalty_control.py
+++ b/experiments/bp_with_penalty_control.py
@@ -117,10 +117,10 @@ def main():
log = train_bp_with_penalty(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, args.lam)
final_acc = evaluate(m, test_loader, dev)
print(f"\nFINAL test acc: {final_acc:.4f}", flush=True)
- print(f"Compare to:")
- print(f" BP-trainable (3-seed mean): 0.609")
- print(f" Penalized DFA lam=1e-2: 0.363")
- print(f" DFA-shallow: 0.349")
+ print(f"Compare to (matched 30-epoch 3-seed values, see paper v2.32):")
+ print(f" BP-trainable no-pen (3-seed): 0.585 ± 0.001")
+ print(f" Penalized DFA lam=1e-2: 0.360 ± 0.001")
+ print(f" DFA-shallow (frozen blocks): 0.349 ± 0.002")
margin = (final_acc - 0.349) * 100
print(f"\nMargin vs DFA-shallow baseline: {margin:+.2f} pp")
if margin > 25:
diff --git a/experiments/dfa_residual_penalty_test.py b/experiments/dfa_residual_penalty_test.py
index 3fa5466..f2b43ec 100644
--- a/experiments/dfa_residual_penalty_test.py
+++ b/experiments/dfa_residual_penalty_test.py
@@ -182,11 +182,11 @@ def main():
final_test = evaluate(m, test_loader, dev)
print(f"\nFINAL test acc: {final_test:.4f}")
- print(f"Compare to:")
- print(f" DFA-vanilla (3-seed mean): 0.308")
- print(f" DFA-shallow (3-seed mean): 0.349")
- print(f" DFA-frozen (3-seed mean): 0.349")
- print(f" BP-trainable (3-seed mean): 0.609")
+ print(f"Compare to (matched 30-epoch 3-seed values, see paper v2.32):")
+ print(f" DFA-vanilla 30ep (3-seed): 0.301 ± 0.005")
+ print(f" DFA-shallow / DFA-frozen: 0.349 ± 0.002")
+ print(f" BP-trainable no-pen 30ep: 0.585 ± 0.001")
+ print(f" BP+pen lam=1e-2 30ep: 0.532 ± 0.006")
out = {'config': vars(args), 'final_test_acc': final_test, 'log': log}
out_path = os.path.join(args.output_dir, f'dfa_pen_lam{args.lam}_s{args.seed}.json')
diff --git a/experiments/resmlp_frozen_blocks_baseline.py b/experiments/resmlp_frozen_blocks_baseline.py
new file mode 100644
index 0000000..c330be2
--- /dev/null
+++ b/experiments/resmlp_frozen_blocks_baseline.py
@@ -0,0 +1,202 @@
+"""
+Frozen-blocks and shallow baselines for the 4-block d=256 ResidualMLP on CIFAR-10.
+This is the codex-round-8 control for ResMLP, parallel to the ViT-Mini frozen-blocks
+experiment that walked back the "DFA trains a 4-block ViT" claim.
+
+Conditions (4 per seed):
+ - BP shallow (num_blocks=0, just embed -> out_ln -> out_head)
+ - BP frozen-blocks (num_blocks=4, blocks frozen at random init, only embed/LN/head trainable)
+ - DFA shallow (num_blocks=0)
+ - DFA frozen-blocks (num_blocks=4, blocks frozen)
+
+If frozen ≈ trainable for DFA: DFA-on-ResMLP also has the same "blocks are passengers"
+problem as ViT-Mini, and the strongest remaining DFA performance result in the paper
+falls. If frozen << trainable: DFA on ResMLP IS doing meaningful block training, and
+the contrast with ViT becomes the most interesting result.
+
+Usage:
+ CUDA_VISIBLE_DEVICES=2 python experiments/resmlp_frozen_blocks_baseline.py --seed 42
+"""
+import sys, os, argparse
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import numpy as np
+
+from models.residual_mlp import ResidualMLP
+
+
+def get_loaders(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (
+ DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2),
+ )
+
+
+def evaluate(model, loader, dev):
+ model.eval()
+ n = c = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ preds = model(x).argmax(-1)
+ c += (preds == y).sum().item()
+ n += x.size(0)
+ return c / n
+
+
+def freeze_blocks(model):
+ for p in model.blocks.parameters():
+ p.requires_grad_(False)
+
+
+def train_bp(model, train_loader, test_loader, dev, epochs, lr, wd, label):
+ """Standard BP. Filters optimizer to requires_grad params."""
+ opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ loss = F.cross_entropy(model(x), y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ if ep % 10 == 0 or ep == 1 or ep == epochs:
+ acc = evaluate(model, test_loader, dev)
+ print(f" [{label}] ep {ep}: test_acc={acc:.4f}", flush=True)
+ return model
+
+
+def train_dfa(model, train_loader, test_loader, dev, epochs, lr, wd, label):
+ """DFA-style: head with true CE, embed (and unfrozen blocks if any) with random feedback.
+ For frozen-blocks: blocks are skipped. For trainable blocks not used here.
+ For num_blocks=0 (shallow): only embed/head are updated.
+ """
+ d_hidden = model.d_hidden
+ L = model.num_blocks
+ C = 10
+ Bs = [torch.randn(d_hidden, C, device=dev) / np.sqrt(C) for _ in range(max(L, 1))]
+
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=lr, weight_decay=wd
+ )
+ sch1 = optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs)
+ sch2 = optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)
+
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(x.size(0)), y] -= 1
+ hL_det = hiddens[-1].detach()
+ # Head update via true CE
+ logits_out = model.out_head(model.out_ln(hL_det))
+ head_opt.zero_grad()
+ F.cross_entropy(logits_out, y).backward()
+ head_opt.step()
+ # Embed update via DFA feedback
+ a0 = (e_T @ Bs[0].T).detach()
+ rms = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ embed_loss = (h0 * (a0 / rms)).sum(-1).mean()
+ embed_opt.zero_grad()
+ embed_loss.backward()
+ embed_opt.step()
+ sch1.step(); sch2.step()
+ if ep % 10 == 0 or ep == 1 or ep == epochs:
+ acc = evaluate(model, test_loader, dev)
+ print(f" [{label}] ep {ep}: test_acc={acc:.4f}", flush=True)
+ return model
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--epochs', type=int, default=100)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--wd', type=float, default=0.01)
+ parser.add_argument('--d_hidden', type=int, default=256)
+ args = parser.parse_args()
+
+ dev = torch.device('cuda:0')
+ print(f"Device: {dev}, seed={args.seed}, epochs={args.epochs}", flush=True)
+ train_loader, test_loader = get_loaders(batch_size=128)
+
+ results = {}
+ input_dim = 32 * 32 * 3
+ C = 10
+
+ # Condition 1: BP shallow (num_blocks=0)
+ print(f"\n=== BP shallow (ResMLP num_blocks=0), seed={args.seed} ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m = ResidualMLP(input_dim, args.d_hidden, C, 0).to(dev)
+ print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True)
+ train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-shallow')
+ results['bp_shallow'] = evaluate(m, test_loader, dev)
+ print(f"FINAL BP-shallow: {results['bp_shallow']:.4f}", flush=True)
+
+ # Condition 2: BP frozen-blocks (num_blocks=4 frozen)
+ print(f"\n=== BP frozen-blocks (ResMLP num_blocks=4, blocks frozen), seed={args.seed} ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m = ResidualMLP(input_dim, args.d_hidden, C, 4).to(dev)
+ freeze_blocks(m)
+ print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True)
+ train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-frozen')
+ results['bp_frozen'] = evaluate(m, test_loader, dev)
+ print(f"FINAL BP-frozen-blocks: {results['bp_frozen']:.4f}", flush=True)
+
+ # Condition 3: DFA shallow (num_blocks=0)
+ print(f"\n=== DFA shallow (ResMLP num_blocks=0), seed={args.seed} ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m = ResidualMLP(input_dim, args.d_hidden, C, 0).to(dev)
+ print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True)
+ train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-shallow')
+ results['dfa_shallow'] = evaluate(m, test_loader, dev)
+ print(f"FINAL DFA-shallow: {results['dfa_shallow']:.4f}", flush=True)
+
+ # Condition 4: DFA frozen-blocks (num_blocks=4 frozen)
+ print(f"\n=== DFA frozen-blocks (ResMLP num_blocks=4, blocks frozen), seed={args.seed} ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m = ResidualMLP(input_dim, args.d_hidden, C, 4).to(dev)
+ freeze_blocks(m)
+ print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True)
+ train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-frozen')
+ results['dfa_frozen'] = evaluate(m, test_loader, dev)
+ print(f"FINAL DFA-frozen-blocks: {results['dfa_frozen']:.4f}", flush=True)
+
+ print(f"\n=== ResMLP frozen/shallow baseline summary, seed={args.seed} ===")
+ print(f" BP-shallow: {results['bp_shallow']:.4f}")
+ print(f" BP-frozen: {results['bp_frozen']:.4f}")
+ print(f" DFA-shallow: {results['dfa_shallow']:.4f}")
+ print(f" DFA-frozen: {results['dfa_frozen']:.4f}")
+ print(f"")
+ print(f"Compare to trainable 4-block ResMLP (3-seed): BP=0.6147 100ep / 0.585 30ep, DFA=0.306 100ep / 0.301 30ep")
+ print(f"")
+ print(f"Interpretation:")
+ print(f" If DFA-frozen ≈ DFA-trainable: blocks are passengers, walk-back parallels ViT")
+ print(f" If DFA-frozen << DFA-trainable: ResMLP DFA actually trains the blocks (interesting contrast with ViT)")
+
+
+if __name__ == '__main__':
+ main()