summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 01:17:47 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 01:17:47 -0500
commit41ace4c1d99a7a8436e42710135d44b925920850 (patch)
tree1142e515ed6835dd1a75da8fba6e5bc49e098a46 /experiments
parent671d9823668197c21b2d35d08d15da0d5c3c4161 (diff)
Add vanilla DFA early-epoch checkpoint training (round 19 disambiguation)
Trains vanilla DFA (no penalty) for max_epoch epochs and saves checkpoints + Bs at the specified early epochs (default: 1, 2, 3, 4, 5). Logs per-layer ||h_l|| and ||g_l|| at each epoch so we can see when ||g_L|| crosses the 1e-7 floor. This is Codex round 19's #3 critical experiment, for disambiguating two hypotheses. Hypothesis A: deep alignment was always there in vanilla DFA but hidden by the post-collapse measurement degeneracy. Hypothesis B: deep alignment was created by the penalty intervention. Test: measure deep-layer cos at vanilla checkpoints from ep 1-3 (when ||g_L|| should still be in the meaningful regime). If cos > 0 at ep 1-2 vanilla -> hypothesis A; if cos ~ 0 at ep 1-2 vanilla -> hypothesis B.
Diffstat (limited to 'experiments')
-rw-r--r--experiments/vanilla_dfa_early_ckpt.py179
1 files changed, 179 insertions, 0 deletions
diff --git a/experiments/vanilla_dfa_early_ckpt.py b/experiments/vanilla_dfa_early_ckpt.py
new file mode 100644
index 0000000..cf69586
--- /dev/null
+++ b/experiments/vanilla_dfa_early_ckpt.py
@@ -0,0 +1,179 @@
+"""
+Train vanilla DFA (no penalty) on the standard 4-block d=256 ResMLP and
+save checkpoints at the early epochs (default: 1-5; epochs 1-3 are the ones
+of interest) BEFORE ‖g_L‖ has collapsed to the numerical floor.
+
+Codex round 19's #3 priority experiment to disambiguate:
+ - Hypothesis A: deep-layer alignment was always present in vanilla DFA but
+ hidden by the post-collapse measurement degeneracy. Penalty just made
+ the measurement interpretable.
+ - Hypothesis B: deep-layer alignment was created by the penalty
+ intervention. Vanilla DFA at any epoch has zero deep alignment.
+
+Test: measure deep-layer cos at vanilla checkpoints from ep 1, 2, 3 (when
+‖g_L‖ should still be in the meaningful regime).
+
+Run:
+ CUDA_VISIBLE_DEVICES=2 python experiments/vanilla_dfa_early_ckpt.py --seed 42
+"""
+import os
+import sys
+import argparse
+import json
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+
+
def get_loaders(batch_size=128):
    """Build the CIFAR-10 train and test ``DataLoader`` pair.

    The training split is augmented with a padded 32x32 random crop and a
    random horizontal flip; the test split is only converted to a tensor.
    Both splits are normalized with the same mean/std constants. The dataset
    is read from ``./data`` and downloaded there if it is missing.

    Args:
        batch_size: Batch size used for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader
        shuffles; the test loader keeps dataset order. Both use two worker
        processes.
    """
    # One Normalize instance is shared by both pipelines so the statistics
    # cannot drift apart between train and test.
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
    )
    augmenting_pipeline = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    plain_pipeline = transforms.Compose([transforms.ToTensor(), normalize])

    train_set = torchvision.datasets.CIFAR10(
        './data', True, download=True, transform=augmenting_pipeline
    )
    test_set = torchvision.datasets.CIFAR10(
        './data', False, download=True, transform=plain_pipeline
    )

    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=2
    )
    test_loader = DataLoader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=2
    )
    return train_loader, test_loader
+
+
def evaluate(model, loader, dev):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    The model is switched to eval mode (and left there). Each input batch is
    flattened to ``(batch, -1)`` and moved to ``dev`` before the forward
    pass; the prediction is the argmax over the last output dimension.

    Args:
        model: Callable module mapping a flattened batch to per-class scores.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        dev: Device the batches are moved to.

    Returns:
        Number of correct predictions divided by the number of samples seen.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            flat = inputs.view(inputs.size(0), -1).to(dev)
            targets = targets.to(dev)
            predicted = model(flat).argmax(-1)
            correct += (predicted == targets).sum().item()
            seen += flat.size(0)
    return correct / seen
+
+
def diagnose_norms(model, x_eval, y_eval, dev):
    """Measure per-layer hidden-state and loss-gradient norms on a fixed batch.

    The model is put in eval mode (and left there). Hidden-state norms come
    from a gradient-free forward pass with ``return_hidden=True``. Gradient
    norms come from a second, differentiable pass that rebuilds the residual
    stream by hand (``embed`` -> each block added to its input -> ``out_ln``
    -> ``out_head``) and differentiates the cross-entropy loss with respect
    to every residual state.

    Args:
        model: Module exposing ``embed``, ``blocks``, ``out_ln``,
            ``out_head`` and a forward that accepts ``return_hidden=True``.
        x_eval: Evaluation inputs, already flattened and on the model device.
        y_eval: Integer class labels matching ``x_eval``.
        dev: Unused here; kept so the call signature stays stable.

    Returns:
        ``(h_norms, g_norms)``: two lists of Python floats. Each entry is
        the median over samples of the last-dimension L2 norm, one entry per
        hidden state / per residual state.
    """
    model.eval()

    with torch.no_grad():
        _, hidden_states = model(x_eval, return_hidden=True)
    h_norms = []
    for state in hidden_states:
        h_norms.append(state.norm(dim=-1).median().item())

    # Differentiable replay of the residual stream so autograd can report the
    # true loss gradient at every depth.
    stream = model.embed(x_eval.detach()).clone().requires_grad_(True)
    residuals = [stream]
    for block in model.blocks:
        stream = stream + block(stream)
        residuals.append(stream)

    logits = model.out_head(model.out_ln(stream))
    grads = torch.autograd.grad(F.cross_entropy(logits, y_eval), residuals)
    g_norms = [grad.norm(dim=-1).median().item() for grad in grads]
    return h_norms, g_norms
+
+
def train_vanilla_dfa(model, train_loader, dev, max_epoch, lr, wd, Bs, x_eval, y_eval, save_at, output_dir, seed):
    """Train ``model`` with vanilla DFA (no penalty) and checkpoint chosen epochs.

    Per batch, in this order:
      1. A gradient-free forward pass yields the logits and hidden states;
         the output error is ``softmax(logits) - onehot(y)``.
      2. The output head (``out_head`` + ``out_ln``) takes a cross-entropy
         step on the detached final hidden state.
      3. Each block ``l`` takes a step on a surrogate loss: the inner product
         of its output with the RMS-normalised feedback ``e_T @ Bs[l].T``,
         with its gradient norm clipped to 1.0.
      4. The embedding takes the same kind of step using ``Bs[0]``
         (no gradient clipping).

    After every epoch (and once before training, logged as epoch 0) the
    per-layer norms from ``diagnose_norms`` are printed and logged.

    Args:
        model: Network exposing ``num_blocks``, ``blocks``, ``embed``,
            ``out_ln``, ``out_head`` and a forward accepting
            ``return_hidden=True``. Assumes ``hiddens[l]`` is the input to
            block ``l`` and ``hiddens[-1]`` the final state -- TODO confirm
            against the model definition.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        dev: Device batches are moved to.
        max_epoch: Number of epochs to run (epochs are numbered from 1).
        lr: AdamW learning rate for every optimizer.
        wd: AdamW weight decay for every optimizer.
        Bs: Fixed feedback matrices, one per block (``main`` builds them
            with shape ``(d, C)``).
        x_eval: Fixed diagnostic inputs passed to ``diagnose_norms``.
        y_eval: Labels for ``x_eval``.
        save_at: Collection of epoch numbers at which to write a checkpoint.
        output_dir: Directory the checkpoint files are written to.
        seed: Seed value, used only in the checkpoint file name.

    Returns:
        A list of ``{"epoch", "h_norms", "g_norms"}`` dicts, one for epoch 0
        and one per trained epoch.
    """
    def make_opt(params):
        # Every parameter group shares the same AdamW hyper-parameters.
        return optim.AdamW(params, lr=lr, weight_decay=wd)

    def feedback_direction(error, feedback):
        # Project the output error through a fixed feedback matrix and scale
        # each sample to (approximately) unit RMS.
        signal = (error @ feedback.T).detach()
        scale = (signal ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
        return signal / scale

    history = []

    def record(epoch):
        # Measure, log and print the per-layer norms for this epoch.
        h_vals, g_vals = diagnose_norms(model, x_eval, y_eval, dev)
        history.append({"epoch": epoch, "h_norms": h_vals, "g_norms": g_vals})
        print(f" ep {epoch}: h_norms={[f'{h:.2e}' for h in h_vals]}, g_norms={[f'{g:.2e}' for g in g_vals]}", flush=True)
        return h_vals, g_vals

    num_layers = model.num_blocks
    block_opts = [make_opt(block.parameters()) for block in model.blocks]
    embed_opt = make_opt(model.embed.parameters())
    head_opt = make_opt(
        list(model.out_head.parameters()) + list(model.out_ln.parameters())
    )

    record(0)

    for ep in range(1, max_epoch + 1):
        model.train()
        for inputs, labels in train_loader:
            inputs = inputs.view(inputs.size(0), -1).to(dev)
            labels = labels.to(dev)
            n_samples = inputs.size(0)

            # Forward pass without autograd; the error signal is softmax - onehot.
            with torch.no_grad():
                logits, hiddens = model(inputs, return_hidden=True)
                e_T = logits.softmax(-1)
                e_T[torch.arange(n_samples), labels] -= 1

            # Output head: ordinary cross-entropy on the detached last state.
            last_hidden = hiddens[-1].detach()
            head_opt.zero_grad()
            head_loss = F.cross_entropy(
                model.out_head(model.out_ln(last_hidden)), labels
            )
            head_loss.backward()
            head_opt.step()

            # Blocks: each one is updated locally from its own feedback matrix.
            for depth in range(num_layers):
                block = model.blocks[depth]
                block_in = hiddens[depth].detach()
                direction = feedback_direction(e_T, Bs[depth])
                block_out = block(block_in)
                surrogate = (block_out * direction).sum(-1).mean()
                block_opts[depth].zero_grad()
                surrogate.backward()
                torch.nn.utils.clip_grad_norm_(block.parameters(), 1.0)
                block_opts[depth].step()

            # Embedding: driven by the first feedback matrix, unclipped.
            embed_direction = feedback_direction(e_T, Bs[0])
            embedded = model.embed(inputs)
            embed_opt.zero_grad()
            (embedded * embed_direction).sum(-1).mean().backward()
            embed_opt.step()

        h_norms, g_norms = record(ep)

        if ep in save_at:
            ckpt_path = os.path.join(output_dir, f"vanilla_dfa_s{seed}_ep{ep}.pt")
            payload = {
                "state_dict": model.state_dict(),
                "Bs": [b.cpu() for b in Bs],
                "epoch": ep,
                "h_norms": h_norms,
                "g_norms": g_norms,
            }
            torch.save(payload, ckpt_path)
            print(f" saved {ckpt_path}", flush=True)

    return history
+
+
def main():
    """Command-line entry point: train vanilla DFA and write checkpoints + log.

    Parses the CLI options, builds the CIFAR-10 loaders, collects a fixed
    1024-sample diagnostic batch from the test loader, seeds the RNGs,
    creates the 4-block d=256 ``ResidualMLP`` together with one random
    feedback matrix per block, runs ``train_vanilla_dfa`` and writes the
    per-epoch log to ``<output_dir>/vanilla_dfa_s<seed>_log.json``.

    Runs on ``cuda:0`` when CUDA is available and falls back to the CPU
    otherwise, so the script no longer crashes on a host without a GPU.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--max_epoch", type=int, default=5)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--wd", type=float, default=0.01)
    p.add_argument("--save_at", type=int, nargs="+", default=[1, 2, 3, 4, 5])
    p.add_argument("--output_dir", type=str, default="results/vanilla_dfa_early_ckpts")
    args = p.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    # Prefer the first visible GPU; fall back to CPU rather than failing on
    # the first ``.to(dev)`` call.
    dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Vanilla DFA early-epoch checkpoint sweep: seed={args.seed}, max_epoch={args.max_epoch}", flush=True)
    if dev.type != "cuda":
        print("CUDA is not available; running on CPU", flush=True)

    # Epochs outside 1..max_epoch are never reached by the training loop, so
    # the checkpoints requested for them would silently never be written.
    unreachable = sorted(e for e in args.save_at if e < 1 or e > args.max_epoch)
    if unreachable:
        print(
            f"Warning: --save_at epochs {unreachable} are outside 1..{args.max_epoch} "
            "and will not be saved",
            flush=True,
        )

    train_loader, test_loader = get_loaders(batch_size=128)

    # Eval batch: the first 1024 test samples, flattened. A running count
    # avoids re-summing every collected batch on each iteration.
    xs, ys = [], []
    collected = 0
    for x, y in test_loader:
        xs.append(x.view(x.size(0), -1))
        ys.append(y)
        collected += x.size(0)
        if collected >= 1024:
            break
    x_eval = torch.cat(xs)[:1024].to(dev)
    y_eval = torch.cat(ys)[:1024].to(dev)

    L, d, C = 4, 256, 10
    # Seeding happens here, after the eval batch has been drawn and right
    # before the model and feedback matrices are created.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    m = ResidualMLP(3072, d, C, L).to(dev)
    Bs = [torch.randn(d, C, device=dev) / np.sqrt(C) for _ in range(L)]
    log = train_vanilla_dfa(
        m, train_loader, dev, args.max_epoch, args.lr, args.wd, Bs,
        x_eval, y_eval, args.save_at, args.output_dir, args.seed,
    )

    out = {"config": vars(args), "log": log}
    out_path = os.path.join(args.output_dir, f"vanilla_dfa_s{args.seed}_log.json")
    with open(out_path, "w") as f:
        json.dump(out, f, indent=2)
    print(f"Saved {out_path}")
+
+
+if __name__ == "__main__":
+ main()