summaryrefslogtreecommitdiff
path: root/experiments/snapshot_evolution_residual_explosion.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/snapshot_evolution_residual_explosion.py')
-rw-r--r--experiments/snapshot_evolution_residual_explosion.py29
1 files changed, 20 insertions, 9 deletions
diff --git a/experiments/snapshot_evolution_residual_explosion.py b/experiments/snapshot_evolution_residual_explosion.py
index 86de4a4..1dc09f2 100644
--- a/experiments/snapshot_evolution_residual_explosion.py
+++ b/experiments/snapshot_evolution_residual_explosion.py
@@ -150,7 +150,8 @@ def train_bp(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_ev
return log
-def train_dfa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_every=1):
+def train_dfa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_every=1,
+ random_targets: bool = False):
d_hidden = model.d_hidden
L = model.num_blocks
C = 10
@@ -172,6 +173,9 @@ def train_dfa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_e
for x, y in train_loader:
x = x.view(x.size(0), -1).to(device)
y = y.to(device)
+ if random_targets:
+ # iid random class targets refreshed every minibatch (codex round 34 sharper variant)
+ y = torch.randint(0, 10, y.shape, device=device)
batch = x.size(0)
with torch.no_grad():
logits, hiddens = model(x, return_hidden=True)
@@ -222,6 +226,10 @@ def main():
help='Replace h = h + f with h = f (non-residual stack of LN-W1-GELU-W2 blocks).')
p.add_argument('--w2_std', type=float, default=0.01,
help='Init std for w2 in each block. Bump to 0.05 for non-residual stack.')
+ p.add_argument('--random_targets', action='store_true',
+ help='Replace each minibatch label with iid random class targets (codex round 34 OPTION A).')
+ p.add_argument('--skip_bp', action='store_true',
+ help='Only train DFA, skip BP. Useful for cheap DFA-only ablations.')
args = p.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
@@ -235,13 +243,15 @@ def main():
L, d, C = args.depth, args.d_hidden, 10
- print("\n=== BP training ===", flush=True)
- torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
- bp_model = ResidualMLP(3072, d, C, L,
- residual_add=not args.no_residual_add,
- w2_std=args.w2_std).to(device)
- bp_log = train_bp(bp_model, train_loader, x_eval, y_eval, device,
- args.epochs, args.lr, args.wd, log_every=args.log_every)
+ bp_log = None
+ if not args.skip_bp:
+ print("\n=== BP training ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ bp_model = ResidualMLP(3072, d, C, L,
+ residual_add=not args.no_residual_add,
+ w2_std=args.w2_std).to(device)
+ bp_log = train_bp(bp_model, train_loader, x_eval, y_eval, device,
+ args.epochs, args.lr, args.wd, log_every=args.log_every)
print("\n=== DFA training ===", flush=True)
torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
@@ -249,7 +259,8 @@ def main():
residual_add=not args.no_residual_add,
w2_std=args.w2_std).to(device)
dfa_log = train_dfa(dfa_model, train_loader, x_eval, y_eval, device,
- args.epochs, args.lr, args.wd, log_every=args.log_every)
+ args.epochs, args.lr, args.wd, log_every=args.log_every,
+ random_targets=args.random_targets)
out = {
'config': vars(args),