From 2fa24acae8bb7f8c026db2f7fdade4a29b640d8d Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Wed, 8 Apr 2026 19:24:06 -0500 Subject: Sync experiment+protocol scripts with v2.32 corrected control values MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The pre-v2.31 unsourced values BP=0.609 and DFA=0.308 (which v2.31 fixed to 0.585 and 0.301 via matched 30-ep controls) were also hardcoded as "compare to" comments in 5 helper scripts: experiments/bp_with_penalty_control.py experiments/dfa_residual_penalty_test.py experiments/resmlp_frozen_blocks_baseline.py protocol/examples/threshold_d_sensitivity.py protocol/examples/plot_penalty_rescue.py These are non-paper-input scripts (their output goes to stdout, not to the paper), so the stale values didn't cause numerical errors in the paper itself. But the original v2.31 BP+pen=0.609 unsourced number bug came from exactly this kind of hardcoded "for-comparison" comment that was never measured. Updating them now to remove the same trap from future runs. Each script now references the matched 30-ep 3-seed values from results/bp_no_penalty_30ep, results/dfa_no_penalty_30ep, results/ dfa_pen_short, and results/bp_with_penalty. protocol/EVIDENCE_SUMMARY.md and PAPER_OUTLINE.md still have stale numbers — these are project scratch documents and not user-facing. Deferred to a separate sweep if needed. Co-Authored-By: Claude Opus 4.6 (1M context) --- experiments/bp_with_penalty_control.py | 8 +- experiments/dfa_residual_penalty_test.py | 10 +- experiments/resmlp_frozen_blocks_baseline.py | 202 +++++++++++++++++++++++++++ protocol/examples/plot_penalty_rescue.py | 4 +- protocol/examples/threshold_d_sensitivity.py | 14 +- 5 files changed, 221 insertions(+), 17 deletions(-) create mode 100644 experiments/resmlp_frozen_blocks_baseline.py diff --git a/experiments/bp_with_penalty_control.py b/experiments/bp_with_penalty_control.py index b986dee..07ee1f1 100644 --- a/experiments/bp_with_penalty_control.py +++ b/experiments/bp_with_penalty_control.py @@ -117,10 +117,10 @@ def main(): log = train_bp_with_penalty(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, args.lam) final_acc = evaluate(m, test_loader, dev) print(f"\nFINAL test acc: {final_acc:.4f}", flush=True) - print(f"Compare to:") - print(f" BP-trainable (3-seed mean): 0.609") - print(f" Penalized DFA lam=1e-2: 0.363") - print(f" DFA-shallow: 0.349") + print(f"Compare to (matched 30-epoch 3-seed values, see paper v2.32):") + print(f" BP-trainable no-pen (3-seed): 0.585 ± 0.001") + print(f" Penalized DFA lam=1e-2: 0.360 ± 0.001") + print(f" DFA-shallow (frozen blocks): 0.349 ± 0.002") margin = (final_acc - 0.349) * 100 print(f"\nMargin vs DFA-shallow baseline: {margin:+.2f} pp") if margin > 25: diff --git a/experiments/dfa_residual_penalty_test.py b/experiments/dfa_residual_penalty_test.py index 3fa5466..f2b43ec 100644 --- a/experiments/dfa_residual_penalty_test.py +++ b/experiments/dfa_residual_penalty_test.py @@ -182,11 +182,11 @@ def main(): final_test = evaluate(m, test_loader, dev) print(f"\nFINAL test acc: {final_test:.4f}") - print(f"Compare to:") - print(f" DFA-vanilla (3-seed mean): 0.308") - print(f" DFA-shallow (3-seed mean): 0.349") - print(f" DFA-frozen (3-seed mean): 0.349") - print(f" BP-trainable (3-seed mean): 0.609") + print(f"Compare to (matched 30-epoch 3-seed values, see paper v2.32):") + print(f" DFA-vanilla 30ep (3-seed): 0.301 ± 0.005") + print(f" DFA-shallow / DFA-frozen: 0.349 ± 0.002") + print(f" BP-trainable no-pen 30ep: 0.585 ± 0.001") + print(f" BP+pen lam=1e-2 30ep: 0.532 ± 0.006") out = {'config': vars(args), 'final_test_acc': final_test, 'log': log} out_path = os.path.join(args.output_dir, f'dfa_pen_lam{args.lam}_s{args.seed}.json') diff --git a/experiments/resmlp_frozen_blocks_baseline.py b/experiments/resmlp_frozen_blocks_baseline.py new file mode 100644 index 0000000..c330be2 --- /dev/null +++ b/experiments/resmlp_frozen_blocks_baseline.py @@ -0,0 +1,202 @@ +""" +Frozen-blocks and shallow baselines for the 4-block d=256 ResidualMLP on CIFAR-10. +This is the codex-round-8 control for ResMLP, parallel to the ViT-Mini frozen-blocks +experiment that walked back the "DFA trains a 4-block ViT" claim. + +Conditions (4 per seed): + - BP shallow (num_blocks=0, just embed -> out_ln -> out_head) + - BP frozen-blocks (num_blocks=4, blocks frozen at random init, only embed/LN/head trainable) + - DFA shallow (num_blocks=0) + - DFA frozen-blocks (num_blocks=4, blocks frozen) + +If frozen ≈ trainable for DFA: DFA-on-ResMLP also has the same "blocks are passengers" +problem as ViT-Mini, and the strongest remaining DFA performance result in the paper +falls. If frozen << trainable: DFA on ResMLP IS doing meaningful block training, and +the contrast with ViT becomes the most interesting result. + +Usage: + CUDA_VISIBLE_DEVICES=2 python experiments/resmlp_frozen_blocks_baseline.py --seed 42 +""" +import sys, os, argparse +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms +import numpy as np + +from models.residual_mlp import ResidualMLP + + +def get_loaders(batch_size=128): + tv_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) + te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) + return ( + DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), + DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2), + ) + + +def evaluate(model, loader, dev): + model.eval() + n = c = 0 + with torch.no_grad(): + for x, y in loader: + x = x.view(x.size(0), -1).to(dev); y = y.to(dev) + preds = model(x).argmax(-1) + c += (preds == y).sum().item() + n += x.size(0) + return c / n + + +def freeze_blocks(model): + for p in model.blocks.parameters(): + p.requires_grad_(False) + + +def train_bp(model, train_loader, test_loader, dev, epochs, lr, wd, label): + """Standard BP. Filters optimizer to requires_grad params.""" + opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=wd) + sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) + for ep in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.view(x.size(0), -1).to(dev); y = y.to(dev) + loss = F.cross_entropy(model(x), y) + opt.zero_grad(); loss.backward(); opt.step() + sch.step() + if ep % 10 == 0 or ep == 1 or ep == epochs: + acc = evaluate(model, test_loader, dev) + print(f" [{label}] ep {ep}: test_acc={acc:.4f}", flush=True) + return model + + +def train_dfa(model, train_loader, test_loader, dev, epochs, lr, wd, label): + """DFA-style: head with true CE, embed (and unfrozen blocks if any) with random feedback. + For frozen-blocks: blocks are skipped. For trainable blocks not used here. + For num_blocks=0 (shallow): only embed/head are updated. + """ + d_hidden = model.d_hidden + L = model.num_blocks + C = 10 + Bs = [torch.randn(d_hidden, C, device=dev) / np.sqrt(C) for _ in range(max(L, 1))] + + embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) + head_opt = optim.AdamW( + list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=lr, weight_decay=wd + ) + sch1 = optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs) + sch2 = optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs) + + for ep in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.view(x.size(0), -1).to(dev); y = y.to(dev) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(-1); e_T[torch.arange(x.size(0)), y] -= 1 + hL_det = hiddens[-1].detach() + # Head update via true CE + logits_out = model.out_head(model.out_ln(hL_det)) + head_opt.zero_grad() + F.cross_entropy(logits_out, y).backward() + head_opt.step() + # Embed update via DFA feedback + a0 = (e_T @ Bs[0].T).detach() + rms = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + h0 = model.embed(x) + embed_loss = (h0 * (a0 / rms)).sum(-1).mean() + embed_opt.zero_grad() + embed_loss.backward() + embed_opt.step() + sch1.step(); sch2.step() + if ep % 10 == 0 or ep == 1 or ep == epochs: + acc = evaluate(model, test_loader, dev) + print(f" [{label}] ep {ep}: test_acc={acc:.4f}", flush=True) + return model + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--epochs', type=int, default=100) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--wd', type=float, default=0.01) + parser.add_argument('--d_hidden', type=int, default=256) + args = parser.parse_args() + + dev = torch.device('cuda:0') + print(f"Device: {dev}, seed={args.seed}, epochs={args.epochs}", flush=True) + train_loader, test_loader = get_loaders(batch_size=128) + + results = {} + input_dim = 32 * 32 * 3 + C = 10 + + # Condition 1: BP shallow (num_blocks=0) + print(f"\n=== BP shallow (ResMLP num_blocks=0), seed={args.seed} ===", flush=True) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + m = ResidualMLP(input_dim, args.d_hidden, C, 0).to(dev) + print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True) + train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-shallow') + results['bp_shallow'] = evaluate(m, test_loader, dev) + print(f"FINAL BP-shallow: {results['bp_shallow']:.4f}", flush=True) + + # Condition 2: BP frozen-blocks (num_blocks=4 frozen) + print(f"\n=== BP frozen-blocks (ResMLP num_blocks=4, blocks frozen), seed={args.seed} ===", flush=True) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + m = ResidualMLP(input_dim, args.d_hidden, C, 4).to(dev) + freeze_blocks(m) + print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True) + train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-frozen') + results['bp_frozen'] = evaluate(m, test_loader, dev) + print(f"FINAL BP-frozen-blocks: {results['bp_frozen']:.4f}", flush=True) + + # Condition 3: DFA shallow (num_blocks=0) + print(f"\n=== DFA shallow (ResMLP num_blocks=0), seed={args.seed} ===", flush=True) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + m = ResidualMLP(input_dim, args.d_hidden, C, 0).to(dev) + print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True) + train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-shallow') + results['dfa_shallow'] = evaluate(m, test_loader, dev) + print(f"FINAL DFA-shallow: {results['dfa_shallow']:.4f}", flush=True) + + # Condition 4: DFA frozen-blocks (num_blocks=4 frozen) + print(f"\n=== DFA frozen-blocks (ResMLP num_blocks=4, blocks frozen), seed={args.seed} ===", flush=True) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + m = ResidualMLP(input_dim, args.d_hidden, C, 4).to(dev) + freeze_blocks(m) + print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True) + train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-frozen') + results['dfa_frozen'] = evaluate(m, test_loader, dev) + print(f"FINAL DFA-frozen-blocks: {results['dfa_frozen']:.4f}", flush=True) + + print(f"\n=== ResMLP frozen/shallow baseline summary, seed={args.seed} ===") + print(f" BP-shallow: {results['bp_shallow']:.4f}") + print(f" BP-frozen: {results['bp_frozen']:.4f}") + print(f" DFA-shallow: {results['dfa_shallow']:.4f}") + print(f" DFA-frozen: {results['dfa_frozen']:.4f}") + print(f"") + print(f"Compare to trainable 4-block ResMLP (3-seed): BP=0.6147 100ep / 0.585 30ep, DFA=0.306 100ep / 0.301 30ep") + print(f"") + print(f"Interpretation:") + print(f" If DFA-frozen ≈ DFA-trainable: blocks are passengers, walk-back parallels ViT") + print(f" If DFA-frozen << DFA-trainable: ResMLP DFA actually trains the blocks (interesting contrast with ViT)") + + +if __name__ == '__main__': + main() diff --git a/protocol/examples/plot_penalty_rescue.py b/protocol/examples/plot_penalty_rescue.py index 37b0fa9..fff300e 100644 --- a/protocol/examples/plot_penalty_rescue.py +++ b/protocol/examples/plot_penalty_rescue.py @@ -10,7 +10,7 @@ Data sources: - vanilla DFA trajectory: results/snapshot_evolution_v2/snapshot_evolution_s42.json - penalized DFA (lam=1e-2): results/dfa_residual_penalty/dfa_pen_lam0.01_s42.json - DFA-shallow baseline 3-seed mean (drawn as horizontal line): 0.349 - - BP-trainable 3-seed mean: 0.609 + - BP-trainable 3-seed mean: 0.6147 (100 ep) / 0.585 (matched 30 ep) Run: python -m protocol.examples.plot_penalty_rescue @@ -91,7 +91,7 @@ def main(): ax.plot([e["epoch"] for e in penalty], [e["acc_eval"] for e in penalty], label=r"DFA + $\lambda \|f_l\|^2$", color="C2", lw=2, marker="o", markersize=4) ax.axhline(0.349, color="k", linestyle="--", lw=1.2, label="DFA-shallow 0.349") - ax.axhline(0.609, color="C0", linestyle=":", lw=1, label="BP-trainable 0.609") + ax.axhline(0.6147, color="C0", linestyle=":", lw=1, label="BP-trainable 100ep 0.615") ax.set_xlabel("epoch", fontsize=10) ax.set_ylabel("test acc", fontsize=10) ax.set_title("(d) headline accuracy", fontsize=11) diff --git a/protocol/examples/threshold_d_sensitivity.py b/protocol/examples/threshold_d_sensitivity.py index d3f2c58..065efc7 100644 --- a/protocol/examples/threshold_d_sensitivity.py +++ b/protocol/examples/threshold_d_sensitivity.py @@ -22,13 +22,15 @@ REPO_ROOT = os.path.dirname( def main(): # 3-seed mean accuracies on 4-block d=256 ResMLP CIFAR-10 + # Updated v2.32 with matched 30-epoch controls conditions = [ - ("BP-trainable", 0.609, 0.004), - ("DFA-shallow", 0.349, 0.002), - ("DFA-vanilla", 0.308, 0.014), - ("DFA-pen lam=1e-3", 0.372, None), # 1 seed - ("DFA-pen lam=1e-2", 0.363, 0.0007), - ("DFA-frozen-rand", 0.349, 0.002), + ("BP-trainable 100ep", 0.6147, 0.004), # protocol_audit + ("BP-trainable 30ep", 0.585, 0.001), # results/bp_no_penalty_30ep + ("BP+pen 30ep lam=1e-2", 0.532, 0.006), # results/bp_with_penalty + ("DFA-shallow", 0.349, 0.002), # frozen baseline + ("DFA-vanilla 100ep", 0.306, 0.006), # protocol_audit + ("DFA-vanilla 30ep", 0.301, 0.005), # results/dfa_no_penalty_30ep + ("DFA+pen 30ep lam=1e-2", 0.360, 0.001), # results/dfa_pen_short ] shallow_acc = 0.349 -- cgit v1.2.3