summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-22 23:46:33 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-22 23:46:33 -0500
commit05c935ab03ee0bdb8597d19466192dfb92ee889d (patch)
treef8411f1724ed6379c392f0bd7629c83cb4eea534 /experiments
parent7aa7123e190cbae3f6ce55050666efcc2ce00796 (diff)
Add vanilla FA (Lillicrap 2016) implementation + full experiment suite
PAPER-CHANGING FINDING: FA is dramatically different from DFA on the same architecture. FA has genuine deep credit quality where DFA has none. Implementation: - experiments/cifar_resmlp.py: added train_fa() + FA diagnostic support FA uses sequential backward credit propagation with d×d random matrices (a_l = B_l @ a_{l+1}) instead of DFA's direct output-error projection (a_l = B_l^T @ e_T). Same local loss form <f_l, a_l>. Core results (A-H, 100ep 3-seed d=256 terminal-LN ResMLP): FA main audit: 0.401 ± 0.009 (DFA: 0.306 ± 0.008) +9.5 pp FA vs frozen: +5.2 pp ABOVE baseline (DFA: -4.3 pp below) FA deep cos: +0.33 (DFA: ~0 degenerate) FA ||h_L||: ~10^5 (DFA: ~5×10^8) 3 OOM less growth FA ||g_L||: ~10^-6 meaningful (DFA: ~10^-10 floor) Mode 1(b) fires: NO for FA; YES for DFA FA+pen lam=1e-2: 0.369 ± 0.003 (DFA+pen: 0.360 ± 0.002) FA+pen lam=1e-4: 0.377 ± 0.006 (DFA+pen lam=1e-4: 0.360) At lam=1e-4, FA already has deep cos +0.30 while DFA has -0.02 FA random-target: acc 0.12 (chance), h_L=1.3e5 (DFA: 1.7e8) FA early 5ep: deep cos already +0.32 (DFA ep1: -0.008) Extension results (d=512 depth sweep, 100ep, s42): L=2: FA 0.350, cos +0.96 (DFA: n/a) L=4: FA 0.424, cos +0.29 (DFA: n/a) L=6: FA 0.401, cos +0.16 (DFA: n/a) L=8: FA 0.409, cos +0.11 (DFA: 0.306, cos -0.0001) L=12: FA 0.404, cos +0.09 (DFA: 0.309, cos -0.0001) FA deep cos is positive at EVERY depth; DFA is ~0 everywhere. FA accuracy exceeds DFA by 5-10 pp at L=8 and L=12. This is the strongest empirical support for the Mode 2 → Mode 1 hypothesis: same local loss, same architecture, same optimizer — only the credit signal differs. FA's sequential propagation produces much better per-layer credit (cos +0.33 vs ~0), which prevents the catastrophic activation growth that DFA exhibits. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/cifar_resmlp.py142
1 files changed, 140 insertions, 2 deletions
diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py
index 7aba671..05a355d 100644
--- a/experiments/cifar_resmlp.py
+++ b/experiments/cifar_resmlp.py
@@ -229,6 +229,116 @@ def train_dfa(model, train_loader, test_loader, device, args):
# =============================================================================
+# Vanilla FA (Lillicrap 2016)
+# =============================================================================
+def train_fa(model, train_loader, test_loader, device, args):
+ """
+ Vanilla Feedback Alignment (Lillicrap et al. 2016).
+ Unlike DFA (which projects output error directly to each layer via
+ a_l = B_l^T @ e_T), FA propagates credit sequentially backward through
+ the block stack using fixed random d×d feedback matrices:
+ a_L = exact gradient at h_L through out_head + out_ln
+ a_l = B_l @ a_{l+1} (random d×d replaces block Jacobian transpose)
+ Each block is updated with the same local loss as DFA: <f_l(h_l), a_l>.
+ """
+ d = model.d_hidden
+ num_classes = args.num_classes
+ L = model.num_blocks
+
+ # Fixed random feedback matrices: d × d (one per block).
+ # These replace the transpose of the block Jacobian dF_l/dh_l in the
+ # backward pass. Contrast with DFA's B_l which are d × num_classes.
+ Bs = [torch.randn(d, d, device=device) / np.sqrt(d) for _ in range(L)]
+
+ # Same optimizer structure as DFA
+ block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+ for block in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=args.lr, weight_decay=args.wd
+ )
+
+ all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ total_loss, correct, total = 0, 0, 0
+
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ if getattr(args, 'random_targets', False):
+ y = torch.randint(0, args.num_classes, y.shape, device=device)
+ batch = x.size(0)
+
+ # Forward pass
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+
+ # 1. Update output head (exact CE gradient, h_L detached)
+ hL_det = hiddens[-1].detach().requires_grad_(True)
+ logits_out = model.out_head(model.out_ln(hL_det))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ # Exact gradient at h_L — FA's starting credit signal
+ a_credit = hL_det.grad.detach() # (batch, d)
+
+ # 2. Update each block with FA credit (backward sequential)
+ for l in range(L - 1, -1, -1):
+ h_l = hiddens[l].detach()
+ # Normalize credit
+ rms = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a_credit / rms
+ # Local surrogate (same form as DFA)
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ if getattr(args, 'penalty_lam', 0.0) > 0.0:
+ local_loss = local_loss + args.penalty_lam * (f_l ** 2).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ block_opts[l].step()
+
+ # Propagate credit backward: FA replaces block Jacobian^T with B_l
+ a_credit = (a_credit @ Bs[l]).detach()
+
+ # 3. Update embedding with FA credit at h_0
+ rms_0 = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_0_norm = a_credit / rms_0
+ h0 = model.embed(x)
+ embed_loss = (h0 * a_0_norm).sum(dim=-1).mean()
+ embed_opt.zero_grad()
+ embed_loss.backward()
+ embed_opt.step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ for s in all_schedulers:
+ s.step()
+
+ train_loss = total_loss / total
+ train_acc = correct / total
+ test_acc = evaluate(model, test_loader, device)
+ log['train_loss'].append(train_loss)
+ log['train_acc'].append(train_acc)
+ log['test_acc'].append(test_acc)
+ if epoch % 10 == 0 or epoch == 1:
+ print(f" [FA] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, test={test_acc:.4f}")
+
+ return log, Bs
+
+
+# =============================================================================
# State Bridge
# =============================================================================
def train_state_bridge(model, train_loader, test_loader, device, args):
@@ -621,6 +731,18 @@ def compute_diagnostics(model, method_name, test_loader, device, args,
'bp_grad_norms_per_layer': bp_grad_norms_per_layer,
}
+ # Pre-compute FA credits if needed (sequential backward from exact h_L gradient)
+ _fa_credits = None
+ if method_name == 'fa' and dfa_Bs is not None:
+ hL_req = hiddens[L].detach().requires_grad_(True)
+ logits_fa = model.out_head(model.out_ln(hL_req))
+ loss_fa = F.cross_entropy(logits_fa, y, reduction='sum')
+ _fa_a_L = torch.autograd.grad(loss_fa, hL_req)[0].detach()
+ _fa_credits = [None] * L
+ _fa_credits[L - 1] = _fa_a_L
+ for ll in range(L - 2, -1, -1):
+ _fa_credits[ll] = (_fa_credits[ll + 1] @ dfa_Bs[ll + 1]).detach()
+
for l in range(L):
h_l = hiddens[l].detach()
t_l = torch.full((batch,), l / L, device=device)
@@ -630,6 +752,8 @@ def compute_diagnostics(model, method_name, test_loader, device, args,
a_l = bp_grads[l]
elif method_name == 'dfa':
a_l = (e_T @ dfa_Bs[l].T).detach()
+ elif method_name == 'fa':
+ a_l = _fa_credits[l]
elif method_name == 'state_bridge':
h_l_req = h_l.clone().requires_grad_(True)
pred_hL = state_predictor(h_l_req, t_l, s)
@@ -720,6 +844,20 @@ def run_experiment(args):
seed_results['dfa'] = {'log': dfa_log, 'diagnostics': dfa_diag, 'drift': dfa_drift}
print(f" Final test acc: {dfa_log['test_acc'][-1]:.4f}")
+ # ---- FA (vanilla Feedback Alignment, Lillicrap 2016) ----
+ if 'fa' in methods_to_run:
+ print("\n--- FA ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_fa = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
+ init_fa = {n: p.clone().detach() for n, p in model_fa.named_parameters()}
+ fa_log, fa_Bs = train_fa(model_fa, train_loader, test_loader, device, args)
+ fa_diag = compute_diagnostics(model_fa, 'fa', test_loader, device, args, dfa_Bs=fa_Bs)
+ fa_drift = feature_drift(init_fa, {n: p.detach() for n, p in model_fa.named_parameters()})
+ seed_results['fa'] = {'log': fa_log, 'diagnostics': fa_diag, 'drift': fa_drift}
+ print(f" Final test acc: {fa_log['test_acc'][-1]:.4f}")
+
# ---- State Bridge ----
if 'state_bridge' in methods_to_run:
print("\n--- State Bridge ---")
@@ -793,8 +931,8 @@ def main():
parser.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456])
parser.add_argument('--gpu', type=int, default=1)
parser.add_argument('--output_dir', type=str, default='results/cifar10')
- parser.add_argument('--methods', type=str, nargs='+', default=['bp', 'dfa', 'state_bridge', 'credit_bridge'],
- help='Subset of methods to run.')
+ parser.add_argument('--methods', type=str, nargs='+', default=['bp', 'dfa', 'fa', 'state_bridge', 'credit_bridge'],
+ help='Subset of methods to run. fa = vanilla Feedback Alignment (Lillicrap 2016).')
parser.add_argument('--random_targets', action='store_true',
help='Replace each minibatch label with i.i.d. random class targets (Mode 1 data-agnostic test).')
parser.add_argument('--penalty_lam', type=float, default=0.0,