diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-08 01:17:47 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-08 01:17:47 -0500 |
| commit | 41ace4c1d99a7a8436e42710135d44b925920850 (patch) | |
| tree | 1142e515ed6835dd1a75da8fba6e5bc49e098a46 /experiments | |
| parent | 671d9823668197c21b2d35d08d15da0d5c3c4161 (diff) | |
Add vanilla DFA early-epoch checkpoint training (round 19 disambiguation)
Trains vanilla DFA (no penalty) for max_epoch epochs and saves checkpoints
+ Bs at specified early epochs (default: 1, 2, 3, 4, 5). Logs per-layer
||h_l|| and ||g_l|| at each epoch so we can see when ||g_L|| crosses the
1e-7 floor.
Codex round 19's #3 critical experiment for disambiguating:
Hypothesis A: deep alignment was always there in vanilla DFA but hidden
by the post-collapse measurement degeneracy
Hypothesis B: deep alignment was created by the penalty intervention
Test: measure deep-layer cos at vanilla checkpoints from ep 1-3 (when
||g_L|| should still be in the meaningful regime).
If cos > 0 at ep 1-2 vanilla -> hypothesis A
If cos ~ 0 at ep 1-2 vanilla -> hypothesis B
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/vanilla_dfa_early_ckpt.py | 179 |
1 files changed, 179 insertions, 0 deletions
diff --git a/experiments/vanilla_dfa_early_ckpt.py b/experiments/vanilla_dfa_early_ckpt.py new file mode 100644 index 0000000..cf69586 --- /dev/null +++ b/experiments/vanilla_dfa_early_ckpt.py @@ -0,0 +1,179 @@ +""" +Train vanilla DFA (no penalty) on the standard 4-block d=256 ResMLP and +save checkpoints at the early epochs (1, 2, 3) BEFORE ‖g_L‖ has +collapsed to the numerical floor. + +Codex round 19's #3 priority experiment to disambiguate: + - Hypothesis A: deep-layer alignment was always present in vanilla DFA but + hidden by the post-collapse measurement degeneracy. Penalty just made + the measurement interpretable. + - Hypothesis B: deep-layer alignment was created by the penalty + intervention. Vanilla DFA at any epoch has zero deep alignment. + +Test: measure deep-layer cos at vanilla checkpoints from ep 1, 2, 3 (when +‖g_L‖ should still be in the meaningful regime). + +Run: + CUDA_VISIBLE_DEVICES=2 python experiments/vanilla_dfa_early_ckpt.py --seed 42 +""" +import os +import sys +import argparse +import json + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torchvision +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from models.residual_mlp import ResidualMLP + + +def get_loaders(batch_size=128): + tv_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) + te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) + return ( + DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), + DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2), + ) + + +def evaluate(model, loader, dev): + model.eval() + n = c = 0 + with torch.no_grad(): + for x, y in loader: + x = x.view(x.size(0), -1).to(dev); y = y.to(dev) + preds = model(x).argmax(-1) + c += (preds == y).sum().item() + n += x.size(0) + return c / n + + +def diagnose_norms(model, x_eval, y_eval, dev): + model.eval() + with torch.no_grad(): + _, hi = model(x_eval, return_hidden=True) + h_norms = [h.norm(dim=-1).median().item() for h in hi] + h0 = model.embed(x_eval.detach()) + hs = [h0.clone().requires_grad_(True)] + for b in model.blocks: + hs.append(hs[-1] + b(hs[-1])) + lo = model.out_head(model.out_ln(hs[-1])) + loss = F.cross_entropy(lo, y_eval) + gs = torch.autograd.grad(loss, hs) + g_norms = [g.norm(dim=-1).median().item() for g in gs] + return h_norms, g_norms + + +def train_vanilla_dfa(model, train_loader, dev, max_epoch, lr, wd, Bs, x_eval, y_eval, save_at, output_dir, seed): + L = model.num_blocks + block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) + head_opt = optim.AdamW( + list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=lr, weight_decay=wd + ) + log = [] + h0_norms, g0_norms = diagnose_norms(model, x_eval, y_eval, dev) + log.append({"epoch": 0, "h_norms": h0_norms, "g_norms": g0_norms}) + print(f" ep 0: h_norms={[f'{h:.2e}' for h in h0_norms]}, g_norms={[f'{g:.2e}' for g in g0_norms]}", flush=True) + + for ep in range(1, max_epoch + 1): + model.train() + for x, y in train_loader: + x = x.view(x.size(0), -1).to(dev); y = y.to(dev) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1 + hL_det = hiddens[-1].detach() + head_opt.zero_grad() + F.cross_entropy(model.out_head(model.out_ln(hL_det)), y).backward() + head_opt.step() + for l in range(L): + h_l = hiddens[l].detach() + a = (e_T @ Bs[l].T).detach() + rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + f = model.blocks[l](h_l) + loss = (f * (a / rms)).sum(-1).mean() + block_opts[l].zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + a0 = (e_T @ Bs[0].T).detach() + rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + h0_emb = model.embed(x) + embed_opt.zero_grad() + (h0_emb * (a0 / rms0)).sum(-1).mean().backward() + embed_opt.step() + h_norms, g_norms = diagnose_norms(model, x_eval, y_eval, dev) + log.append({"epoch": ep, "h_norms": h_norms, "g_norms": g_norms}) + print(f" ep {ep}: h_norms={[f'{h:.2e}' for h in h_norms]}, g_norms={[f'{g:.2e}' for g in g_norms]}", flush=True) + if ep in save_at: + ckpt_path = os.path.join(output_dir, f"vanilla_dfa_s{seed}_ep{ep}.pt") + torch.save({ + "state_dict": model.state_dict(), + "Bs": [b.cpu() for b in Bs], + "epoch": ep, + "h_norms": h_norms, + "g_norms": g_norms, + }, ckpt_path) + print(f" saved {ckpt_path}", flush=True) + return log + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--seed", type=int, default=42) + p.add_argument("--max_epoch", type=int, default=5) + p.add_argument("--lr", type=float, default=1e-3) + p.add_argument("--wd", type=float, default=0.01) + p.add_argument("--save_at", type=int, nargs="+", default=[1, 2, 3, 4, 5]) + p.add_argument("--output_dir", type=str, default="results/vanilla_dfa_early_ckpts") + args = p.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + dev = torch.device("cuda:0") + print(f"Vanilla DFA early-epoch checkpoint sweep: seed={args.seed}, max_epoch={args.max_epoch}", flush=True) + train_loader, test_loader = get_loaders(batch_size=128) + + # Eval batch + xs, ys = [], [] + for x, y in test_loader: + xs.append(x.view(x.size(0), -1)); ys.append(y) + if sum(xb.size(0) for xb in xs) >= 1024: + break + x_eval = torch.cat(xs)[:1024].to(dev) + y_eval = torch.cat(ys)[:1024].to(dev) + + L, d, C = 4, 256, 10 + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + m = ResidualMLP(3072, d, C, L).to(dev) + Bs = [torch.randn(d, C, device=dev) / np.sqrt(C) for _ in range(L)] + log = train_vanilla_dfa(m, train_loader, dev, args.max_epoch, args.lr, args.wd, Bs, x_eval, y_eval, args.save_at, args.output_dir, args.seed) + + out = {"config": vars(args), "log": log} + out_path = os.path.join(args.output_dir, f"vanilla_dfa_s{args.seed}_log.json") + with open(out_path, "w") as f: + json.dump(out, f, indent=2) + print(f"Saved {out_path}") + + +if __name__ == "__main__": + main() |
