summaryrefslogtreecommitdiff
path: root/experiments/snapshot_evolution_residual_explosion.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/snapshot_evolution_residual_explosion.py')
-rw-r--r--experiments/snapshot_evolution_residual_explosion.py78
1 files changed, 78 insertions, 0 deletions
diff --git a/experiments/snapshot_evolution_residual_explosion.py b/experiments/snapshot_evolution_residual_explosion.py
index 1dc09f2..6155d94 100644
--- a/experiments/snapshot_evolution_residual_explosion.py
+++ b/experiments/snapshot_evolution_residual_explosion.py
@@ -212,6 +212,73 @@ def train_dfa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_e
return log
+def train_fa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_every=1):
+ """Train `model` with feedback alignment and return the diagnostics log.
+
+ FA (Lillicrap 2016): sequential backward credit with d×d random matrices.
+ Canonical implementation matching cifar_resmlp.py train_fa():
+ - mean reduction (default)
+ - gradient taken BEFORE head step (old head weights)
+ - top-down block update, credit propagated after each block
+ - NO grad clipping
+
+ Args:
+ model: network exposing d_hidden, num_blocks, blocks, embed, out_ln and out_head,
+ and callable as model(x, return_hidden=True) -> (logits, hiddens).
+ train_loader: iterable of (x, y) batches; x is flattened to (batch, -1) before use.
+ x_eval, y_eval: evaluation tensors, passed unchanged to diagnose().
+ device: device for the batches and for the random feedback matrices.
+ epochs: number of training epochs; also T_max of every cosine LR schedule.
+ lr, wd: learning rate and weight decay shared by all AdamW optimizers.
+ log_every: run diagnose() every this many epochs (the last epoch is always logged).
+
+ Returns:
+ List of diagnose() results, each tagged with its 'epoch'; entry 0 is taken before any update.
+ """
+ d_hidden = model.d_hidden
+ L = model.num_blocks
+ Bs_fa = [torch.randn(d_hidden, d_hidden, device=device) / np.sqrt(d_hidden) for _ in range(L)]  # one fixed random feedback matrix per block, scaled by 1/sqrt(d_hidden); no optimizer ever touches them
+ block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks]  # separate optimizer per block so each block can be stepped on its own
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=lr, weight_decay=wd)  # the head optimizer covers out_ln as well as out_head
+ all_sch = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)])  # one cosine schedule per optimizer
+ log = []
+ d0 = diagnose(model, x_eval, y_eval)  # baseline diagnostics before any parameter update
+ d0['epoch'] = 0
+ log.append(d0)
+ print(f" [FA] Ep 0: ||h||_med={d0['hidden_norms']} acc={d0['acc_eval']:.4f}", flush=True)
+ for epoch in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)  # flatten each sample to a vector
+ y = y.to(device)
+ # Forward pass without autograd: only the hidden activations are used below (logits is unused)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ # Head update — re-run the head on a detached copy of the last hidden state so its .grad can be read; the credit is taken BEFORE the step (old head weights)
+ hL_det = hiddens[-1].detach().requires_grad_(True)
+ logits_out = model.out_head(model.out_ln(hL_det))
+ loss_out = F.cross_entropy(logits_out, y) # mean reduction
+ head_opt.zero_grad()
+ loss_out.backward()
+ a_credit = hL_det.grad.detach() # gradient of the loss w.r.t. the last hidden state, computed with the old head weights
+ head_opt.step()
+ # Top-down block updates with sequential FA credit propagation (last block first)
+ for l in range(L - 1, -1, -1):
+ h_l = hiddens[l].detach()  # presumably the input to block l (hiddens[0] being the embed output) — confirm against the model's forward
+ rms = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6  # per-sample RMS of the credit over the last axis; the epsilon guards against division by zero
+ a_norm = a_credit / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()  # surrogate loss: its gradient w.r.t. f_l is a_norm scaled by 1/(number of samples), so backward() pushes the normalized credit through this block only
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ block_opts[l].step() # no grad clipping
+ a_credit = (a_credit @ Bs_fa[l]).detach()  # FA step: the un-normalized credit is passed down through this block's fixed random matrix
+ # Embed update with final propagated credit (the credit left after the last loop iteration, l = 0)
+ rms_0 = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)  # recomputed with autograd enabled; the forward above ran under no_grad
+ embed_loss = (h0 * (a_credit / rms_0)).sum(dim=-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+ for s in all_sch:  # NOTE(review): presumably stepped once per epoch (T_max=epochs); indentation is not visible in this view — confirm
+ s.step()
+ if epoch % log_every == 0 or epoch == epochs:
+ d = diagnose(model, x_eval, y_eval)
+ d['epoch'] = epoch
+ log.append(d)
+ print(f" [FA] Ep {epoch}: ||h_L||={d['hidden_norms'][-1]:.3e} "
+ f"||g_2||={d['bp_grad_norms_per_sample_med'][2]:.3e} "  # NOTE(review): index 2 is hard-coded; assumes diagnose() reports at least three entries here — confirm
+ f"acc={d['acc_eval']:.4f}", flush=True)
+ return log
+
+
def main():
p = argparse.ArgumentParser()
p.add_argument('--output_dir', type=str, default='results/snapshot_evolution_v2')
@@ -262,11 +329,22 @@ def main():
args.epochs, args.lr, args.wd, log_every=args.log_every,
random_targets=args.random_targets)
+ fa_log = None
+ if not args.skip_bp and not args.random_targets: # FA only when doing full run
+ print("\n=== FA training ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ fa_model = ResidualMLP(3072, d, C, L,
+ residual_add=not args.no_residual_add,
+ w2_std=args.w2_std).to(device)
+ fa_log = train_fa(fa_model, train_loader, x_eval, y_eval, device,
+ args.epochs, args.lr, args.wd, log_every=args.log_every)
+
out = {
'config': vars(args),
'depth': L, 'd_hidden': d, 'num_classes': C,
'bp_log': bp_log,
'dfa_log': dfa_log,
+ 'fa_log': fa_log,
}
out_path = os.path.join(args.output_dir, f'snapshot_evolution_s{args.seed}.json')
with open(out_path, 'w') as f: