diff options
Diffstat (limited to 'experiments/snapshot_fa_studentnet.py')
| -rw-r--r-- | experiments/snapshot_fa_studentnet.py | 94 |
1 files changed, 94 insertions, 0 deletions
diff --git a/experiments/snapshot_fa_studentnet.py b/experiments/snapshot_fa_studentnet.py new file mode 100644 index 0000000..887365c --- /dev/null +++ b/experiments/snapshot_fa_studentnet.py @@ -0,0 +1,94 @@ +"""FA-only snapshot evolution for StudentNet (synthetic teacher-student).""" +import os, sys, json, argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from experiments.confirmatory_paper_experiments import ( + StudentNet, TeacherNet, generate_synth_dataset, set_seed +) +from experiments.snapshot_synth_residual_explosion import diagnose_synth + + +def train_fa_synth(model, train_loader, x_eval, y_eval, device, epochs, lr, wd): + """Canonical FA for StudentNet: mean reduction, grad before step, no clipping.""" + d_hidden = model.d_hidden + L = model.num_blocks + Bs = [torch.randn(d_hidden, d_hidden, device=device) / np.sqrt(d_hidden) for _ in range(L)] + block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] + head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd) + all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \ + [optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)] + log = [] + d0 = diagnose_synth(model, x_eval, y_eval); d0['epoch'] = 0; log.append(d0) + print(f" [FA] Ep 0: acc={d0['acc_eval']:.4f}", flush=True) + for ep in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.to(device); y = y.to(device) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + # Head update — grad BEFORE step (old head) + hL_det = hiddens[-1].detach().requires_grad_(True) + logits_out = model.out_head(hL_det) + loss_out = F.cross_entropy(logits_out, y) # mean reduction + head_opt.zero_grad() + loss_out.backward() + a_credit = hL_det.grad.detach() + head_opt.step() + # Top-down block updates, propagate credit after each + for l in range(L - 1, -1, -1): + h_l = hiddens[l].detach() + rms = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + f_l = model.blocks[l](h_l) + local_loss = (f_l * (a_credit / rms)).sum(dim=-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + block_opts[l].step() # no clipping + a_credit = (a_credit @ Bs[l]).detach() + # No embed for StudentNet (input is already d_hidden) + for s in all_sch: s.step() + d = diagnose_synth(model, x_eval, y_eval); d['epoch'] = ep; log.append(d) + if ep % 5 == 0 or ep in (1, epochs): + print(f" [FA] Ep {ep}: ||h_L||={d['hidden_norms'][-1]:.3e} " + f"||g||={d['bp_grad_per_sample_l2_med'][2]:.3e} " + f"acc={d['acc_eval']:.4f}", flush=True) + return log + + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--output', type=str, required=True) + p.add_argument('--epochs', type=int, default=80) + p.add_argument('--seed', type=int, default=42) + p.add_argument('--alpha', type=float, default=1.0) + p.add_argument('--depth', type=int, default=4) + p.add_argument('--d_hidden', type=int, default=128) + args = p.parse_args() + + device = torch.device('cuda:0') + L, d, C = args.depth, args.d_hidden, 10 + set_seed(args.seed) + teacher = TeacherNet(d, L, C, args.alpha, seed=0).to(device) + X_tr, Y_tr = generate_synth_dataset(teacher, 50*256, d, device, seed=args.seed) + X_te, Y_te = generate_synth_dataset(teacher, 2000, d, device, seed=args.seed+10000) + train_loader = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=256, shuffle=True) + + print(f"StudentNet FA: alpha={args.alpha}, L={L}, d={d}, seed={args.seed}", flush=True) + set_seed(args.seed) + model = StudentNet(d, C, L, args.alpha).to(device) + fa_log = train_fa_synth(model, train_loader, X_te.to(device), Y_te.to(device), + device, args.epochs, 1e-3, 0.01) + + with open(args.output, 'w') as f: + json.dump({'fa_log': fa_log, 'seed': args.seed, 'alpha': args.alpha, + 'depth': L, 'd_hidden': d}, f, indent=2) + print(f"Saved: {args.output}", flush=True) + + +if __name__ == '__main__': + main() |
