summaryrefslogtreecommitdiff
path: root/experiments/snapshot_fa_only.py
blob: cdc69aea6689479bbc43e245d846b5bbb1455d96 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""Quick FA-only snapshot evolution. Reuses the full script's train_fa + diagnose."""
import os, sys, json, argparse
import numpy as np
import torch
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from experiments.snapshot_evolution_residual_explosion import (
    get_cifar10, fixed_eval_buffer, train_fa
)
from models.residual_mlp import ResidualMLP

def main():
    """Train a single FA model and save its per-epoch log as JSON.

    Command-line arguments:
        --output    path of the JSON file to write (required). Missing parent
                    directories are created before training starts.
        --epochs    number of epochs passed to ``train_fa`` (default 100).
        --seed      seed for the torch / numpy / CUDA RNGs (default 42).
        --depth     passed as the 4th positional argument of ``ResidualMLP``
                    (default 4).
        --d_hidden  passed as the 2nd positional argument of ``ResidualMLP``
                    (default 256).
        --device    torch device string. When omitted, ``cuda:0`` is used if
                    CUDA is available, otherwise ``cpu``.

    The written JSON object has the keys ``fa_log`` (whatever ``train_fa``
    returns), ``seed``, ``depth`` and ``d_hidden``.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--output', type=str, required=True)
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--depth', type=int, default=4)
    p.add_argument('--d_hidden', type=int, default=256)
    p.add_argument('--device', type=str, default=None,
                   help="torch device (default: 'cuda:0' if available, else 'cpu')")
    args = p.parse_args()

    # Fail fast: make sure the output directory exists *before* the long
    # training run instead of discovering a bad path at the final json.dump.
    os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)

    # Previously hard-coded to 'cuda:0', which crashed on CPU-only machines.
    # With no --device flag the CUDA path is unchanged when a GPU is present.
    if args.device is not None:
        device = torch.device(args.device)
    elif torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    train_loader, test_loader = get_cifar10(batch_size=128)
    x_eval, y_eval = fixed_eval_buffer(test_loader, device, n_samples=1024)

    L, d, C = args.depth, args.d_hidden, 10
    print(f"FA snapshot: depth={L}, d={d}, seed={args.seed}, epochs={args.epochs}", flush=True)

    # NOTE(review): seeds are set only after the loaders and the eval buffer are
    # built, so any randomness inside get_cifar10 / fixed_eval_buffer is not
    # controlled by --seed. Confirm those helpers are deterministic.
    torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
    # 3072 is presumably the flattened 3x32x32 CIFAR-10 image; C=10 classes.
    model = ResidualMLP(3072, d, C, L).to(device)
    # 1e-3 and 0.01 are positional hyper-parameters of train_fa (meaning defined
    # by its signature, not visible here); log_every=1 logs at every step of
    # its logging cadence (presumably every epoch).
    fa_log = train_fa(model, train_loader, x_eval, y_eval, device,
                      args.epochs, 1e-3, 0.01, log_every=1)

    with open(args.output, 'w', encoding='utf-8') as f:
        json.dump({'fa_log': fa_log, 'seed': args.seed, 'depth': L, 'd_hidden': d}, f, indent=2)
    print(f"Saved: {args.output}", flush=True)

if __name__ == "__main__":
    # Entry point: run the CLI only when executed as a script, never on import.
    main()