summaryrefslogtreecommitdiff
path: root/experiments/dfa_penalty_freshB.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/dfa_penalty_freshB.py')
-rw-r--r--experiments/dfa_penalty_freshB.py183
1 files changed, 183 insertions, 0 deletions
diff --git a/experiments/dfa_penalty_freshB.py b/experiments/dfa_penalty_freshB.py
new file mode 100644
index 0000000..82b192d
--- /dev/null
+++ b/experiments/dfa_penalty_freshB.py
@@ -0,0 +1,183 @@
+"""
+DFA canonical λ=1e-2 training + checkpoint save + fresh-B null calibration.
+Runs after the main penalty sweep to produce the null calibration on the canonical checkpoint.
+"""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision, torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+from metrics.credit_metrics import cosine_similarity_batch
+
+
+def get_data(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2))
+
+
+def train_dfa_canonical(model, train_loader, device, epochs, lr, wd, penalty_lam):
+ """Canonical DFA from cifar_resmlp.py: no grad clipping, mean reduction."""
+ d = model.d_hidden
+ L = model.num_blocks
+ C = 10
+ Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+ block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=lr, weight_decay=wd)
+ all_sch = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)])
+
+ for epoch in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1
+ hL_det = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL_det))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad(); loss_out.backward(); head_opt.step()
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a_dfa = (e_T @ Bs[l].T).detach()
+ rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * (a_dfa / rms)).sum(dim=-1).mean()
+ if penalty_lam > 0:
+ local_loss = local_loss + penalty_lam * (f_l ** 2).sum(dim=-1).mean()
+ block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step()
+ a_0 = (e_T @ Bs[0].T).detach()
+ rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+ for s in all_sch: s.step()
+ if epoch % 10 == 0 or epoch == epochs:
+ print(f" [DFA pen] ep {epoch}", flush=True)
+ return Bs
+
+
+def compute_deep_cosine(model, Bs, x_eval, y_eval, device):
+ """Compute per-layer DFA cosine on eval buffer."""
+ model.eval()
+ L = model.num_blocks
+ h0 = model.embed(x_eval.detach())
+ hs = [h0.clone().requires_grad_(True)]
+ for b in model.blocks:
+ hs.append(hs[-1] + b(hs[-1]))
+ logits = model.out_head(model.out_ln(hs[-1]))
+ loss = F.cross_entropy(logits, y_eval)
+ grads = torch.autograd.grad(loss, hs)
+ with torch.no_grad():
+ e_T = logits.softmax(-1)
+ e_T[torch.arange(x_eval.size(0)), y_eval] -= 1
+ cos_per_layer = []
+ for l in range(L):
+ a_dfa = (e_T @ Bs[l].T).detach()
+ cos_per_layer.append(cosine_similarity_batch(a_dfa, grads[l].detach()))
+ acc = (logits.argmax(-1) == y_eval).float().mean().item()
+ g_norms = [g.norm(dim=-1).median().item() for g in grads]
+ h_norms = [h.detach().norm(dim=-1).median().item() for h in hs]
+ return cos_per_layer, acc, g_norms, h_norms
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--seed', type=int, default=42)
+ p.add_argument('--output_dir', type=str, default='results/dfa_canonical_freshB')
+ p.add_argument('--n_fresh', type=int, default=20)
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ device = torch.device('cuda:0')
+ train_loader, test_loader = get_data(128)
+
+ # Fixed eval buffer
+ xs, ys = [], []
+ for x, y in test_loader:
+ xs.append(x.view(x.size(0), -1)); ys.append(y)
+ if sum(xb.size(0) for xb in xs) >= 128:
+ break
+ x_eval = torch.cat(xs)[:128].to(device)
+ y_eval = torch.cat(ys)[:128].to(device)
+
+ L, d, C = 4, 256, 10
+
+ # Train DFA with λ=1e-2
+ print(f"Training DFA canonical λ=0.01, seed={args.seed}", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ model = ResidualMLP(3072, d, C, L).to(device)
+ training_Bs = train_dfa_canonical(model, train_loader, device, 30, 1e-3, 0.01, 0.01)
+
+ # Save checkpoint
+ ckpt_path = os.path.join(args.output_dir, f'dfa_canonical_lam0.01_s{args.seed}.pt')
+ torch.save({'state_dict': model.state_dict(),
+ 'Bs': [B.cpu() for B in training_Bs],
+ 'seed': args.seed}, ckpt_path)
+ print(f"Saved checkpoint: {ckpt_path}", flush=True)
+
+ # Compute cosine with training Bs
+ cos_training, acc, g_norms, h_norms = compute_deep_cosine(model, training_Bs, x_eval, y_eval, device)
+ deep_cos_training = float(np.mean(cos_training[1:])) # exclude layer 0
+ print(f"Training-Bs: acc={acc:.4f}, deep cos={deep_cos_training:+.4f}")
+ print(f" per-layer cos: {[f'{c:+.4f}' for c in cos_training]}")
+ print(f" ||g_l||: {[f'{g:.2e}' for g in g_norms]}")
+ print(f" ||h_l||: {[f'{h:.2e}' for h in h_norms]}")
+
+ # Fresh-B null calibration
+ print(f"\nFresh-B null calibration ({args.n_fresh} draws)...", flush=True)
+ fresh_deep_cos = []
+ fresh_per_layer = []
+ for i in range(args.n_fresh):
+ fresh_Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+ cos_fresh, _, _, _ = compute_deep_cosine(model, fresh_Bs, x_eval, y_eval, device)
+ deep_fresh = float(np.mean(cos_fresh[1:]))
+ fresh_deep_cos.append(deep_fresh)
+ fresh_per_layer.append(cos_fresh)
+ fresh_mean = np.mean(fresh_deep_cos)
+ fresh_std_ddof1 = np.std(fresh_deep_cos, ddof=1)
+ print(f"Fresh-Bs deep cos: {fresh_mean:+.4f} ± {fresh_std_ddof1:.4f} (ddof=1)")
+
+ # Save results
+ out = {
+ 'description': f'Canonical DFA λ=0.01 s={args.seed} + fresh-B null (N={args.n_fresh})',
+ 'training_Bs_deep_cos': deep_cos_training,
+ 'training_Bs_per_layer_cos': cos_training,
+ 'training_Bs_acc': acc,
+ 'training_Bs_g_norms': g_norms,
+ 'training_Bs_h_norms': h_norms,
+ 'fresh_Bs_n_draws': args.n_fresh,
+ 'fresh_Bs_deep_cos_per_draw': fresh_deep_cos,
+ 'fresh_Bs_deep_mean': fresh_mean,
+ 'fresh_Bs_deep_std_ddof1': fresh_std_ddof1,
+ 'fresh_Bs_per_layer_mean': [float(np.mean([fl[l] for fl in fresh_per_layer])) for l in range(L)],
+ }
+ out_path = os.path.join(args.output_dir, f'freshB_null_canonical_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(out, f, indent=2)
+ print(f"Saved: {out_path}", flush=True)
+
+
+if __name__ == '__main__':
+ main()