diff options
Diffstat (limited to 'experiments/snapshot_evolution_residual_explosion.py')
| -rw-r--r-- | experiments/snapshot_evolution_residual_explosion.py | 78 |
1 files changed, 78 insertions, 0 deletions
def _fa_rms_normalize(credit, eps=1e-6):
    """Divide each row of ``credit`` by (its root-mean-square over the last dim + eps)."""
    row_rms = (credit ** 2).mean(dim=-1, keepdim=True).sqrt() + eps
    return credit / row_rms


def train_fa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_every=1):
    """Train ``model`` with Feedback Alignment (FA, Lillicrap 2016).

    Credit is passed backwards sequentially through one fixed random
    ``d_hidden x d_hidden`` matrix per block instead of the transposed
    forward weights. Per mini-batch:

    1. One forward pass under ``no_grad`` collects the hidden states.
    2. The head (``out_ln`` + ``out_head``) is trained on mean-reduced
       cross-entropy; the gradient w.r.t. the last hidden state is read
       BEFORE the head optimizer steps, i.e. against the old head weights.
    3. Blocks are updated top-down. Each block minimises the inner product
       of its output with the RMS-normalised credit; afterwards the raw
       (un-normalised) credit is multiplied by that block's feedback matrix.
    4. The embedding is updated the same way with the final credit.

    No gradient clipping is applied anywhere.

    Args:
        model: network exposing ``d_hidden``, ``num_blocks``, ``blocks``,
            ``embed``, ``out_ln``, ``out_head`` and accepting
            ``model(x, return_hidden=True)`` -> ``(logits, hiddens)``.
            NOTE(review): ``hiddens[l]`` is used as the input of block ``l``
            and ``hiddens[-1]`` as the head input -- confirm against the model.
        train_loader: iterable of ``(x, y)`` batches; ``x`` is flattened to
            ``(batch, -1)`` before use.
        x_eval, y_eval: forwarded unchanged to ``diagnose``.
        device: device for the batches and the feedback matrices.
        epochs: number of epochs; also ``T_max`` of every cosine schedule.
        lr: AdamW learning rate shared by all optimizers.
        wd: AdamW weight decay shared by all optimizers.
        log_every: run ``diagnose`` every this many epochs (the last epoch
            is always logged).

    Returns:
        List of the dicts returned by ``diagnose``, each with an added
        ``'epoch'`` key; the first entry is the pre-training state (epoch 0).
    """
    width = model.d_hidden
    depth = model.num_blocks

    # Fixed random feedback matrices, one per block, drawn before anything
    # else so the RNG stream is consumed in the same order as before.
    feedback = [
        torch.randn(width, width, device=device) / np.sqrt(width)
        for _ in range(depth)
    ]

    # One independent AdamW optimizer per block, plus one each for the
    # embedding and for the output head (linear head + its layer norm).
    per_block_opt = [
        optim.AdamW(blk.parameters(), lr=lr, weight_decay=wd)
        for blk in model.blocks
    ]
    embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
    head_params = [*model.out_head.parameters(), *model.out_ln.parameters()]
    head_opt = optim.AdamW(head_params, lr=lr, weight_decay=wd)

    # Cosine schedules, stepped once per epoch: blocks first, then embed, then head.
    schedulers = [
        optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
        for opt in (*per_block_opt, embed_opt, head_opt)
    ]

    # Snapshot of the untrained model.
    history = []
    baseline = diagnose(model, x_eval, y_eval)
    baseline['epoch'] = 0
    history.append(baseline)
    print(f" [FA] Ep 0: ||h||_med={baseline['hidden_norms']} acc={baseline['acc_eval']:.4f}", flush=True)

    for epoch in range(1, epochs + 1):
        model.train()
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)

            # Forward pass only to record hidden states; no graph is kept.
            with torch.no_grad():
                _, hiddens = model(x, return_hidden=True)

            # Head update. The gradient w.r.t. the top hidden state is
            # captured before head_opt.step(), so it reflects the old head.
            top_hidden = hiddens[-1].detach().requires_grad_(True)
            head_logits = model.out_head(model.out_ln(top_hidden))
            head_loss = F.cross_entropy(head_logits, y)  # mean reduction
            head_opt.zero_grad()
            head_loss.backward()
            credit = top_hidden.grad.detach()
            head_opt.step()

            # Top-down block updates with sequential FA credit propagation.
            for idx in reversed(range(depth)):
                block_input = hiddens[idx].detach()
                unit_credit = _fa_rms_normalize(credit)
                block_output = model.blocks[idx](block_input)
                block_loss = (block_output * unit_credit).sum(dim=-1).mean()
                per_block_opt[idx].zero_grad()
                block_loss.backward()
                per_block_opt[idx].step()  # no grad clipping
                # Propagate the raw credit through this block's feedback matrix.
                credit = (credit @ feedback[idx]).detach()

            # Embedding update driven by the fully propagated credit.
            unit_credit = _fa_rms_normalize(credit)
            embedded = model.embed(x)
            embed_loss = (embedded * unit_credit).sum(dim=-1).mean()
            embed_opt.zero_grad()
            embed_loss.backward()
            embed_opt.step()

        for sched in schedulers:
            sched.step()

        # Log on the configured cadence and always on the final epoch.
        if epoch % log_every != 0 and epoch != epochs:
            continue
        snapshot = diagnose(model, x_eval, y_eval)
        snapshot['epoch'] = epoch
        history.append(snapshot)
        print(f" [FA] Ep {epoch}: ||h_L||={snapshot['hidden_norms'][-1]:.3e} "
              f"||g_2||={snapshot['bp_grad_norms_per_sample_med'][2]:.3e} "
              f"acc={snapshot['acc_eval']:.4f}", flush=True)

    return history
'config': vars(args), 'depth': L, 'd_hidden': d, 'num_classes': C, 'bp_log': bp_log, 'dfa_log': dfa_log, + 'fa_log': fa_log, } out_path = os.path.join(args.output_dir, f'snapshot_evolution_s{args.seed}.json') with open(out_path, 'w') as f: |
