From aa73718eb6427d7da3b9cb416275802d90c4b2ed Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Sun, 14 Jun 2026 04:06:32 -0500 Subject: Add new experiment scripts, figures, and paper assets; untrack pyc/build artifacts Co-Authored-By: Claude Opus 4.8 (1M context) --- experiments/snapshot_fa_only.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 experiments/snapshot_fa_only.py (limited to 'experiments/snapshot_fa_only.py') diff --git a/experiments/snapshot_fa_only.py b/experiments/snapshot_fa_only.py new file mode 100644 index 0000000..cdc69ae --- /dev/null +++ b/experiments/snapshot_fa_only.py @@ -0,0 +1,38 @@ +"""Quick FA-only snapshot evolution. Reuses the full script's train_fa + diagnose.""" +import os, sys, json, argparse +import numpy as np +import torch +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from experiments.snapshot_evolution_residual_explosion import ( + get_cifar10, fixed_eval_buffer, train_fa +) +from models.residual_mlp import ResidualMLP + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--output', type=str, required=True) + p.add_argument('--epochs', type=int, default=100) + p.add_argument('--seed', type=int, default=42) + p.add_argument('--depth', type=int, default=4) + p.add_argument('--d_hidden', type=int, default=256) + args = p.parse_args() + + device = torch.device('cuda:0') + train_loader, test_loader = get_cifar10(batch_size=128) + x_eval, y_eval = fixed_eval_buffer(test_loader, device, n_samples=1024) + + L, d, C = args.depth, args.d_hidden, 10 + print(f"FA snapshot: depth={L}, d={d}, seed={args.seed}, epochs={args.epochs}", flush=True) + + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + model = ResidualMLP(3072, d, C, L).to(device) + fa_log = train_fa(model, train_loader, x_eval, y_eval, device, + args.epochs, 1e-3, 0.01, log_every=1) + + with open(args.output, 'w') as f: + json.dump({'fa_log': fa_log, 'seed': args.seed, 'depth': L, 'd_hidden': d}, f, indent=2) + print(f"Saved: {args.output}", flush=True) + +if __name__ == '__main__': + main() -- cgit v1.2.3