diff options
Diffstat (limited to 'experiments/snapshot_evolution_residual_explosion.py')
| -rw-r--r-- | experiments/snapshot_evolution_residual_explosion.py | 29 |
1 files changed, 20 insertions, 9 deletions
diff --git a/experiments/snapshot_evolution_residual_explosion.py b/experiments/snapshot_evolution_residual_explosion.py index 86de4a4..1dc09f2 100644 --- a/experiments/snapshot_evolution_residual_explosion.py +++ b/experiments/snapshot_evolution_residual_explosion.py @@ -150,7 +150,8 @@ def train_bp(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_ev return log -def train_dfa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_every=1): +def train_dfa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_every=1, + random_targets: bool = False): d_hidden = model.d_hidden L = model.num_blocks C = 10 @@ -172,6 +173,9 @@ def train_dfa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_e for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) + if random_targets: + # iid random class targets refreshed every minibatch (codex round 34 sharper variant) + y = torch.randint(0, 10, y.shape, device=device) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) @@ -222,6 +226,10 @@ def main(): help='Replace h = h + f with h = f (non-residual stack of LN-W1-GELU-W2 blocks).') p.add_argument('--w2_std', type=float, default=0.01, help='Init std for w2 in each block. Bump to 0.05 for non-residual stack.') + p.add_argument('--random_targets', action='store_true', + help='Replace each minibatch label with iid random class targets (codex round 34 OPTION A).') + p.add_argument('--skip_bp', action='store_true', + help='Only train DFA, skip BP. Useful for cheap DFA-only ablations.') args = p.parse_args() os.makedirs(args.output_dir, exist_ok=True) @@ -235,13 +243,15 @@ def main(): L, d, C = args.depth, args.d_hidden, 10 - print("\n=== BP training ===", flush=True) - torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) - bp_model = ResidualMLP(3072, d, C, L, - residual_add=not args.no_residual_add, - w2_std=args.w2_std).to(device) - bp_log = train_bp(bp_model, train_loader, x_eval, y_eval, device, - args.epochs, args.lr, args.wd, log_every=args.log_every) + bp_log = None + if not args.skip_bp: + print("\n=== BP training ===", flush=True) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + bp_model = ResidualMLP(3072, d, C, L, + residual_add=not args.no_residual_add, + w2_std=args.w2_std).to(device) + bp_log = train_bp(bp_model, train_loader, x_eval, y_eval, device, + args.epochs, args.lr, args.wd, log_every=args.log_every) print("\n=== DFA training ===", flush=True) torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) @@ -249,7 +259,8 @@ def main(): residual_add=not args.no_residual_add, w2_std=args.w2_std).to(device) dfa_log = train_dfa(dfa_model, train_loader, x_eval, y_eval, device, - args.epochs, args.lr, args.wd, log_every=args.log_every) + args.epochs, args.lr, args.wd, log_every=args.log_every, + random_targets=args.random_targets) out = { 'config': vars(args), |
