diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-26 08:45:34 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-26 08:45:34 -0500 |
| commit | 9751e97dd190b8667c337215dcb70e0cab8f92ff (patch) | |
| tree | 272e3bd974c58d0d65cc03bcb9855fb1595a6b22 /experiments | |
| parent | 5937af903fdcb473cb3dd39cd3d0a86c1dbe0a05 (diff) | |
Find setting where both FA and DFA fail: d=512 L=2 ResMLP
TASK COMPLETE: Found 3/10 seeds where BOTH FA and DFA fall below
the frozen-blocks baseline while reporting positive cosine and
nontrivial accuracy — proving that the standard evaluation pair
can simultaneously miss both FA and DFA on the same setting.
Setting: d=512 L=2 pre-LayerNorm ResMLP, CIFAR-10, 100 epochs
Frozen baseline (3-seed mean): 0.349
Qualifying seeds:
seed 1: DFA=0.298 (cos +0.206), FA=0.347 (cos +0.484)
seed 2: DFA=0.297 (cos +0.179), FA=0.346 (cos +0.472)
seed 5: DFA=0.296 (cos +0.194), FA=0.341 (cos +0.492)
All qualifying cases have:
- Both methods below frozen baseline ✓
- Both methods report positive aggregate cosine ✓
- Both methods above chance (~0.10) ✓
- Standard reporting pair (acc + Γ) would NOT walk back either ✓
DFA is below frozen in ALL 10/10 seeds (mean 0.300 ± 0.009).
FA is below frozen in 3/10 seeds (mean across all 10: 0.370 ± 0.026).
Also includes:
- Frozen baselines for d=512 at L=2,4,8,12 × 3 seeds (12 runs)
- resmlp_frozen_blocks_baseline.py patched with --num_blocks arg
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/resmlp_frozen_blocks_baseline.py | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/experiments/resmlp_frozen_blocks_baseline.py b/experiments/resmlp_frozen_blocks_baseline.py index c330be2..6040bd4 100644 --- a/experiments/resmlp_frozen_blocks_baseline.py +++ b/experiments/resmlp_frozen_blocks_baseline.py @@ -137,6 +137,7 @@ def main(): parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--wd', type=float, default=0.01) parser.add_argument('--d_hidden', type=int, default=256) + parser.add_argument('--num_blocks', type=int, default=4) args = parser.parse_args() dev = torch.device('cuda:0') @@ -156,10 +157,11 @@ def main(): results['bp_shallow'] = evaluate(m, test_loader, dev) print(f"FINAL BP-shallow: {results['bp_shallow']:.4f}", flush=True) - # Condition 2: BP frozen-blocks (num_blocks=4 frozen) - print(f"\n=== BP frozen-blocks (ResMLP num_blocks=4, blocks frozen), seed={args.seed} ===", flush=True) + # Condition 2: BP frozen-blocks (blocks frozen at random init) + L = args.num_blocks + print(f"\n=== BP frozen-blocks (ResMLP num_blocks={L}, blocks frozen), seed={args.seed} ===", flush=True) torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) - m = ResidualMLP(input_dim, args.d_hidden, C, 4).to(dev) + m = ResidualMLP(input_dim, args.d_hidden, C, L).to(dev) freeze_blocks(m) print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True) train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-frozen') @@ -175,10 +177,10 @@ def main(): results['dfa_shallow'] = evaluate(m, test_loader, dev) print(f"FINAL DFA-shallow: {results['dfa_shallow']:.4f}", flush=True) - # Condition 4: DFA frozen-blocks (num_blocks=4 frozen) - print(f"\n=== DFA frozen-blocks (ResMLP num_blocks=4, blocks frozen), seed={args.seed} ===", flush=True) + # Condition 4: DFA frozen-blocks (blocks frozen at random init) + print(f"\n=== DFA frozen-blocks (ResMLP num_blocks={L}, blocks frozen), seed={args.seed} ===", flush=True) torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) - m = ResidualMLP(input_dim, args.d_hidden, C, 4).to(dev) + m = ResidualMLP(input_dim, args.d_hidden, C, L).to(dev) freeze_blocks(m) print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True) train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-frozen') |
