diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-22 23:46:33 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-22 23:46:33 -0500 |
| commit | 05c935ab03ee0bdb8597d19466192dfb92ee889d (patch) | |
| tree | f8411f1724ed6379c392f0bd7629c83cb4eea534 /experiments | |
| parent | 7aa7123e190cbae3f6ce55050666efcc2ce00796 (diff) | |
Add vanilla FA (Lillicrap 2016) implementation + full experiment suite
PAPER-CHANGING FINDING: FA is dramatically different from DFA on the
same architecture. FA has genuine deep credit quality where DFA has none.
Implementation:
- experiments/cifar_resmlp.py: added train_fa() + FA diagnostic support
FA uses sequential backward credit propagation with d×d random matrices
(a_l = B_l @ a_{l+1}) instead of DFA's direct output-error projection
(a_l = B_l^T @ e_T). Same local loss form <f_l, a_l>.
Core results (A-H, 100ep 3-seed d=256 terminal-LN ResMLP):
FA main audit: 0.401 ± 0.009 (DFA: 0.306 ± 0.008) +9.5 pp
FA vs frozen: +5.2 pp ABOVE baseline (DFA: 4.3 pp BELOW baseline)
FA deep cos: +0.33 (DFA: ~0 degenerate)
FA ||h_L||: ~10^5 (DFA: ~5×10^8) >3 orders of magnitude less growth
FA ||g_L||: ~10^-6 meaningful (DFA: ~10^-10 floor)
Mode 1(b) fires: NO for FA; YES for DFA
FA+pen lam=1e-2: 0.369 ± 0.003 (DFA+pen: 0.360 ± 0.002)
FA+pen lam=1e-4: 0.377 ± 0.006 (DFA+pen lam=1e-4: 0.360)
At lam=1e-4, FA already has deep cos +0.30 while DFA has -0.02
FA random-target: acc 0.12 (chance), h_L=1.3e5 (DFA: 1.7e8)
FA early 5ep: deep cos already +0.32 (DFA ep1: -0.008)
Extension results (d=512 depth sweep, 100ep, s42):
L=2: FA 0.350, cos +0.96 (DFA: n/a)
L=4: FA 0.424, cos +0.29 (DFA: n/a)
L=6: FA 0.401, cos +0.16 (DFA: n/a)
L=8: FA 0.409, cos +0.11 (DFA: 0.306, cos -0.0001)
L=12: FA 0.404, cos +0.09 (DFA: 0.309, cos -0.0001)
FA deep cos is positive at EVERY depth; DFA is ~0 wherever it was measured (L=8 and L=12).
FA accuracy exceeds DFA by ~10 pp at L=8 (+10.3 pp) and L=12 (+9.5 pp).
This is the strongest empirical support for the Mode 2 → Mode 1
hypothesis: same local loss, same architecture, same optimizer —
only the credit signal differs. FA's sequential propagation produces
much better per-layer credit (cos +0.33 vs ~0), which prevents the
catastrophic activation growth that DFA exhibits.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/cifar_resmlp.py | 142 |
1 files changed, 140 insertions, 2 deletions
diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py index 7aba671..05a355d 100644 --- a/experiments/cifar_resmlp.py +++ b/experiments/cifar_resmlp.py @@ -229,6 +229,116 @@ def train_dfa(model, train_loader, test_loader, device, args): # ============================================================================= +# Vanilla FA (Lillicrap 2016) +# ============================================================================= +def train_fa(model, train_loader, test_loader, device, args): + """ + Vanilla Feedback Alignment (Lillicrap et al. 2016). + Unlike DFA (which projects output error directly to each layer via + a_l = B_l^T @ e_T), FA propagates credit sequentially backward through + the block stack using fixed random d×d feedback matrices: + a_L = exact gradient at h_L through out_head + out_ln + a_l = B_l @ a_{l+1} (random d×d replaces block Jacobian transpose) + Each block is updated with the same local loss as DFA: <f_l(h_l), a_l>. + """ + d = model.d_hidden + num_classes = args.num_classes + L = model.num_blocks + + # Fixed random feedback matrices: d × d (one per block). + # These replace the transpose of the block Jacobian dF_l/dh_l in the + # backward pass. Contrast with DFA's B_l which are d × num_classes. 
+ Bs = [torch.randn(d, d, device=device) / np.sqrt(d) for _ in range(L)] + + # Same optimizer structure as DFA + block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd) + for block in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd) + head_opt = optim.AdamW( + list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=args.lr, weight_decay=args.wd + ) + + all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]) + + log = {'train_loss': [], 'train_acc': [], 'test_acc': []} + + for epoch in range(1, args.epochs + 1): + model.train() + total_loss, correct, total = 0, 0, 0 + + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + if getattr(args, 'random_targets', False): + y = torch.randint(0, args.num_classes, y.shape, device=device) + batch = x.size(0) + + # Forward pass + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + + # 1. Update output head (exact CE gradient, h_L detached) + hL_det = hiddens[-1].detach().requires_grad_(True) + logits_out = model.out_head(model.out_ln(hL_det)) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad() + loss_out.backward() + head_opt.step() + + # Exact gradient at h_L — FA's starting credit signal + a_credit = hL_det.grad.detach() # (batch, d) + + # 2. 
Update each block with FA credit (backward sequential) + for l in range(L - 1, -1, -1): + h_l = hiddens[l].detach() + # Normalize credit + rms = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a_credit / rms + # Local surrogate (same form as DFA) + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_norm).sum(dim=-1).mean() + if getattr(args, 'penalty_lam', 0.0) > 0.0: + local_loss = local_loss + args.penalty_lam * (f_l ** 2).sum(dim=-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + block_opts[l].step() + + # Propagate credit backward: FA replaces block Jacobian^T with B_l + a_credit = (a_credit @ Bs[l]).detach() + + # 3. Update embedding with FA credit at h_0 + rms_0 = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_0_norm = a_credit / rms_0 + h0 = model.embed(x) + embed_loss = (h0 * a_0_norm).sum(dim=-1).mean() + embed_opt.zero_grad() + embed_loss.backward() + embed_opt.step() + + total_loss += loss_val.item() * batch + correct += (logits.argmax(1) == y).sum().item() + total += batch + + for s in all_schedulers: + s.step() + + train_loss = total_loss / total + train_acc = correct / total + test_acc = evaluate(model, test_loader, device) + log['train_loss'].append(train_loss) + log['train_acc'].append(train_acc) + log['test_acc'].append(test_acc) + if epoch % 10 == 0 or epoch == 1: + print(f" [FA] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, test={test_acc:.4f}") + + return log, Bs + + +# ============================================================================= # State Bridge # ============================================================================= def train_state_bridge(model, train_loader, test_loader, device, args): @@ -621,6 +731,18 @@ def compute_diagnostics(model, method_name, test_loader, device, args, 'bp_grad_norms_per_layer': bp_grad_norms_per_layer, } + # Pre-compute FA credits if needed (sequential backward from exact h_L gradient) + _fa_credits = None + if method_name == 'fa' and 
dfa_Bs is not None: + hL_req = hiddens[L].detach().requires_grad_(True) + logits_fa = model.out_head(model.out_ln(hL_req)) + loss_fa = F.cross_entropy(logits_fa, y, reduction='sum') + _fa_a_L = torch.autograd.grad(loss_fa, hL_req)[0].detach() + _fa_credits = [None] * L + _fa_credits[L - 1] = _fa_a_L + for ll in range(L - 2, -1, -1): + _fa_credits[ll] = (_fa_credits[ll + 1] @ dfa_Bs[ll + 1]).detach() + for l in range(L): h_l = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) @@ -630,6 +752,8 @@ def compute_diagnostics(model, method_name, test_loader, device, args, a_l = bp_grads[l] elif method_name == 'dfa': a_l = (e_T @ dfa_Bs[l].T).detach() + elif method_name == 'fa': + a_l = _fa_credits[l] elif method_name == 'state_bridge': h_l_req = h_l.clone().requires_grad_(True) pred_hL = state_predictor(h_l_req, t_l, s) @@ -720,6 +844,20 @@ def run_experiment(args): seed_results['dfa'] = {'log': dfa_log, 'diagnostics': dfa_diag, 'drift': dfa_drift} print(f" Final test acc: {dfa_log['test_acc'][-1]:.4f}") + # ---- FA (vanilla Feedback Alignment, Lillicrap 2016) ---- + if 'fa' in methods_to_run: + print("\n--- FA ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_fa = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) + init_fa = {n: p.clone().detach() for n, p in model_fa.named_parameters()} + fa_log, fa_Bs = train_fa(model_fa, train_loader, test_loader, device, args) + fa_diag = compute_diagnostics(model_fa, 'fa', test_loader, device, args, dfa_Bs=fa_Bs) + fa_drift = feature_drift(init_fa, {n: p.detach() for n, p in model_fa.named_parameters()}) + seed_results['fa'] = {'log': fa_log, 'diagnostics': fa_diag, 'drift': fa_drift} + print(f" Final test acc: {fa_log['test_acc'][-1]:.4f}") + # ---- State Bridge ---- if 'state_bridge' in methods_to_run: print("\n--- State Bridge ---") @@ -793,8 +931,8 @@ def main(): parser.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 
456]) parser.add_argument('--gpu', type=int, default=1) parser.add_argument('--output_dir', type=str, default='results/cifar10') - parser.add_argument('--methods', type=str, nargs='+', default=['bp', 'dfa', 'state_bridge', 'credit_bridge'], - help='Subset of methods to run.') + parser.add_argument('--methods', type=str, nargs='+', default=['bp', 'dfa', 'fa', 'state_bridge', 'credit_bridge'], + help='Subset of methods to run. fa = vanilla Feedback Alignment (Lillicrap 2016).') parser.add_argument('--random_targets', action='store_true', help='Replace each minibatch label with i.i.d. random class targets (Mode 1 data-agnostic test).') parser.add_argument('--penalty_lam', type=float, default=0.0, |
