summary | refs | log | tree | commit | diff
path: root/experiments/snapshot_fa_studentnet.py
blob: 887365cda913843c772fc3f77bac6e55170a84f6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""FA-only snapshot evolution for StudentNet (synthetic teacher-student)."""
import os, sys, json, argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from experiments.confirmatory_paper_experiments import (
    StudentNet, TeacherNet, generate_synth_dataset, set_seed
)
from experiments.snapshot_synth_residual_explosion import diagnose_synth


def train_fa_synth(model, train_loader, x_eval, y_eval, device, epochs, lr, wd):
    """Train ``model`` with FA credit signals and log a diagnostic per epoch.

    Credit for the hidden blocks is not obtained by backpropagating through
    them.  Instead, the gradient of the head loss w.r.t. the last hidden
    state is pushed downwards through fixed random matrices (one per block),
    and every block is trained on a local surrogate loss built from that
    signal.  Cross-entropy uses mean reduction, the head gradient is read
    before the head optimizer steps, and no gradient clipping is applied.

    Args:
        model: network exposing ``d_hidden``, ``num_blocks``, ``blocks`` and
            ``out_head``; calling ``model(x, return_hidden=True)`` must
            return a ``(logits, hiddens)`` pair.  ``hiddens[l]`` is fed to
            ``blocks[l]`` and ``hiddens[-1]`` to ``out_head``.
        train_loader: iterable of ``(x, y)`` mini-batches.
        x_eval, y_eval: evaluation tensors handed to ``diagnose_synth``.
        device: device for the feedback matrices and the mini-batches.
        epochs: number of training epochs (also the cosine schedule period).
        lr: AdamW learning rate for every optimizer.
        wd: AdamW weight decay for every optimizer.

    Returns:
        A list with ``epochs + 1`` entries: the ``diagnose_synth`` result
        before training (``'epoch'`` 0) and after each epoch, each with an
        added ``'epoch'`` key.
    """
    width = model.d_hidden
    n_blocks = model.num_blocks

    # Fixed random feedback matrices, drawn once and never trained.
    scale = np.sqrt(width)
    feedback = []
    for _ in range(n_blocks):
        feedback.append(torch.randn(width, width, device=device) / scale)

    # One optimizer per block plus one for the head, each on its own
    # cosine schedule (blocks first, head last).
    block_optimizers = [
        optim.AdamW(block.parameters(), lr=lr, weight_decay=wd)
        for block in model.blocks
    ]
    head_optimizer = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
    schedulers = [
        optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
        for opt in (*block_optimizers, head_optimizer)
    ]

    history = []

    def record(epoch):
        # Run the diagnostic, tag it with the epoch and keep it.
        stats = diagnose_synth(model, x_eval, y_eval)
        stats['epoch'] = epoch
        history.append(stats)
        return stats

    d0 = record(0)
    print(f"  [FA] Ep 0: acc={d0['acc_eval']:.4f}", flush=True)

    for ep in range(1, epochs + 1):
        model.train()
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            # Single forward pass without autograd; every update below
            # reuses these (pre-update) hidden states.
            with torch.no_grad():
                _, acts = model(batch_x, return_hidden=True)

            # Head update: capture d(loss)/d(h_L) with the OLD head weights,
            # i.e. before the head optimizer steps.
            top = acts[-1].detach().requires_grad_(True)
            head_loss = F.cross_entropy(model.out_head(top), batch_y)  # mean reduction
            head_optimizer.zero_grad()
            head_loss.backward()
            credit = top.grad.detach()
            head_optimizer.step()

            # Walk the blocks top-down; after each block update the credit
            # is mapped one level down through that block's random matrix.
            for idx in reversed(range(n_blocks)):
                block_in = acts[idx].detach()
                # Per-sample RMS keeps the surrogate target at unit scale.
                credit_rms = (credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
                block_out = model.blocks[idx](block_in)
                surrogate = (block_out * (credit / credit_rms)).sum(dim=-1).mean()
                opt = block_optimizers[idx]
                opt.zero_grad()
                surrogate.backward()
                opt.step()  # no clipping
                credit = (credit @ feedback[idx]).detach()
            # No embedding update: StudentNet's input is already d_hidden wide.

        for scheduler in schedulers:
            scheduler.step()

        d = record(ep)
        if ep % 5 == 0 or ep == 1 or ep == epochs:
            print(f"  [FA] Ep {ep}: ||h_L||={d['hidden_norms'][-1]:.3e} "
                  f"||g||={d['bp_grad_per_sample_l2_med'][2]:.3e} "
                  f"acc={d['acc_eval']:.4f}", flush=True)
    return history


def main():
    """Run one FA training experiment on the synthetic teacher-student task.

    Parses the command line, builds a teacher (always seeded with 0) and
    draws 50*256 training samples and 2000 evaluation samples from it,
    trains a StudentNet with ``train_fa_synth`` (lr 1e-3, weight decay 0.01,
    batch size 256) and writes the per-epoch log plus the run settings to
    ``--output`` as JSON.

    ``--device`` is optional: when omitted, ``cuda:0`` is used if CUDA is
    available and the CPU otherwise, so the script no longer crashes on
    machines without a GPU.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--output', type=str, required=True)
    p.add_argument('--epochs', type=int, default=80)
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--alpha', type=float, default=1.0)
    p.add_argument('--depth', type=int, default=4)
    p.add_argument('--d_hidden', type=int, default=128)
    p.add_argument('--device', type=str, default=None,
                   help="torch device string; default: cuda:0 if available, else cpu")
    args = p.parse_args()

    if args.device is not None:
        device = torch.device(args.device)
    elif torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
        print("CUDA not available; falling back to CPU", flush=True)

    # Create the output directory up front so a bad path fails (or is fixed)
    # before the training run instead of discarding its results afterwards.
    os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)

    L, d, C = args.depth, args.d_hidden, 10
    set_seed(args.seed)
    teacher = TeacherNet(d, L, C, args.alpha, seed=0).to(device)
    X_tr, Y_tr = generate_synth_dataset(teacher, 50*256, d, device, seed=args.seed)
    # Evaluation data uses a seed offset so it differs from the training draw.
    X_te, Y_te = generate_synth_dataset(teacher, 2000, d, device, seed=args.seed+10000)
    train_loader = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=256, shuffle=True)

    print(f"StudentNet FA: alpha={args.alpha}, L={L}, d={d}, seed={args.seed}", flush=True)
    # Re-seed right before building the student (presumably so its init does
    # not depend on how much RNG state data generation consumed).
    set_seed(args.seed)
    model = StudentNet(d, C, L, args.alpha).to(device)
    fa_log = train_fa_synth(model, train_loader, X_te.to(device), Y_te.to(device),
                            device, args.epochs, 1e-3, 0.01)

    with open(args.output, 'w') as f:
        json.dump({'fa_log': fa_log, 'seed': args.seed, 'alpha': args.alpha,
                   'depth': L, 'd_hidden': d}, f, indent=2)
    print(f"Saved: {args.output}", flush=True)


# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()