summaryrefslogtreecommitdiff
path: root/experiments/cifar_online_vector_credit.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-03-24 18:03:55 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-03-24 18:03:55 -0500
commit5550e2cac45758e579810ae36bf716a0b819cebc (patch)
tree28f263e4030d6d5144af5badcebd533b27f4da78 /experiments/cifar_online_vector_credit.py
parent3d17cbad98f320905c52509c7f18691eab8bf2a0 (diff)
Add Phase 5: vector field audit, frozen CIFAR transfer, online pilot
Phase 5A: Audit passes — shuffle control collapses, gains are real Phase 5B: Transfer SUCCESS — vec_M4 beats scalar CB by +0.25 Gamma, +0.31 rho on frozen CIFAR Phase 5C: Online FAILURE — vec does worse than scalar CB online despite better frozen credit Core finding: bottleneck is in local surrogate / co-adaptation, not estimator quality Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments/cifar_online_vector_credit.py')
-rw-r--r--experiments/cifar_online_vector_credit.py404
1 files changed, 404 insertions, 0 deletions
diff --git a/experiments/cifar_online_vector_credit.py b/experiments/cifar_online_vector_credit.py
new file mode 100644
index 0000000..3a3762c
--- /dev/null
+++ b/experiments/cifar_online_vector_credit.py
@@ -0,0 +1,404 @@
+"""
+Phase 5C: Online Shallow CIFAR Vector Credit Pilot.
+
+Minimal pilot: does vector field's frozen credit gain translate to online training?
+
+Compare DFA, ScalarCB_eT, VectorField_eT_M4 on CIFAR-10, L=4, d=256.
+Sweep warmup_ratio and term_weight.
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
+from metrics.credit_metrics import (
+ cosine_similarity_batch, perturbation_correlation, nudging_test
+)
+
+
+class VectorCreditNet(nn.Module):
+ def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, d_hidden))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h, t, s):
+ h_normed = self.ln(h)
+ t_emb = self.time_embed(t)
+ inp = torch.cat([h_normed, t_emb, s], dim=-1)
+ return self.net(inp)
+
+
+def get_cifar10(batch_size=128):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+ train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+ test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+ return train_loader, test_loader
+
+
+def evaluate(model, test_loader, device):
+ model.eval()
+ correct, total = 0, 0
+ with torch.no_grad():
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ logits = model(x)
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+ return correct / total
+
+
+def train_dfa(model, train_loader, test_loader, device, epochs, lr, wd):
+ d = model.d_hidden
+ L = model.num_blocks
+ Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr, weight_decay=wd)
+ scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+ log = {'train_loss': [], 'test_acc': []}
+ for epoch in range(1, epochs + 1):
+ model.train()
+ total_loss, correct, total = 0, 0, 0
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ hL = hiddens[-1].detach()
+ loss_out = F.cross_entropy(model.out_head(model.out_ln(hL)), y)
+ head_opt.zero_grad(); loss_out.backward(); head_opt.step()
+ for l in range(L):
+ a = (e_T @ Bs[l].T).detach()
+ rms = (a**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ f = model.blocks[l](hiddens[l].detach())
+ ll = (f * (a/rms)).sum(-1).mean()
+ block_opts[l].zero_grad(); ll.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+ a0 = (e_T @ Bs[0].T).detach()
+ rms0 = (a0**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ el = (model.embed(x) * (a0/rms0)).sum(-1).mean()
+ embed_opt.zero_grad(); el.backward(); embed_opt.step()
+ total_loss += loss_val.item() * batch; correct += (logits.argmax(1) == y).sum().item(); total += batch
+ for s in scheds: s.step()
+ test_acc = evaluate(model, test_loader, device)
+ log['train_loss'].append(total_loss/total); log['test_acc'].append(test_acc)
+ if epoch % 20 == 0 or epoch == 1:
+ print(f" [DFA] Ep {epoch}: loss={total_loss/total:.4f}, test={test_acc:.4f}")
+ return log, Bs
+
+
+def train_vector_online(model, train_loader, test_loader, device, epochs, lr, lr_fb, wd,
+ M=4, warmup_ratio=0.2, term_weight=1.0, eps=1e-3, beta=1.0):
+ d = model.d_hidden
+ L = model.num_blocks
+ warmup_epochs = max(1, int(epochs * warmup_ratio))
+
+ vector_net = VectorCreditNet(d_hidden=d, s_dim=10, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+ Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
+
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr, weight_decay=wd)
+ vec_opt = optim.Adam(vector_net.parameters(), lr=lr_fb)
+ scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+
+ log = {'train_loss': [], 'test_acc': [], 'vloss': []}
+
+ for epoch in range(1, epochs + 1):
+ model.train(); vector_net.train()
+ credit_blend = 0.0 if epoch <= warmup_epochs else min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs))
+ total_loss, correct, total, total_vloss = 0, 0, 0, 0
+
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ hL = hiddens[-1].detach()
+
+ # Train vector net: terminal matching
+ loss_term = torch.tensor(0.0, device=device)
+ if term_weight > 0:
+ t_L = torch.ones(batch, device=device)
+ a_term = vector_net(hL, t_L, s)
+ hL_req = hL.clone().requires_grad_(True)
+ logits_tgt = model.out_head(model.out_ln(hL_req))
+ ce = F.cross_entropy(logits_tgt, y, reduction='sum')
+ delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()
+ loss_term = ((a_term - delta_L) ** 2).sum(-1).mean()
+
+ # Perturbation target: subsample 1 layer
+ l_train = np.random.randint(0, L)
+ h_l = hiddens[l_train].detach()
+ t_l = torch.full((batch,), l_train / L, device=device)
+ a_l = vector_net(h_l, t_l, s)
+
+ loss_proj = torch.tensor(0.0, device=device)
+ for _ in range(M):
+ v = torch.randn_like(h_l)
+ v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
+ with torch.no_grad():
+ lp = F.cross_entropy(model.forward_from_layer(h_l + eps*v, l_train), y, reduction='none')
+ lm = F.cross_entropy(model.forward_from_layer(h_l - eps*v, l_train), y, reduction='none')
+ g_j = (lp - lm) / (2*eps)
+ loss_proj = loss_proj + (((a_l * v).sum(-1) - g_j.detach())**2).mean()
+ loss_proj = loss_proj / M
+
+ vloss = term_weight * loss_term + beta * loss_proj
+ vec_opt.zero_grad(); vloss.backward()
+ torch.nn.utils.clip_grad_norm_(vector_net.parameters(), 1.0)
+ vec_opt.step()
+ total_vloss += vloss.item() * batch
+
+ # Compute credits
+ with torch.no_grad():
+ vec_credits = [vector_net(hiddens[l].detach(),
+ torch.full((batch,), l/L, device=device), s).detach() for l in range(L)]
+ dfa_credits = [(e_T @ Bs[l].T).detach() for l in range(L)]
+
+ credits = []
+ for l in range(L):
+ if credit_blend >= 1.0:
+ credits.append(vec_credits[l])
+ elif credit_blend <= 0.0:
+ credits.append(dfa_credits[l])
+ else:
+ vr = (vec_credits[l]**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ dr = (dfa_credits[l]**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ credits.append(credit_blend * vec_credits[l]/vr + (1-credit_blend) * dfa_credits[l]/dr)
+
+ # Update head
+ loss_out = F.cross_entropy(model.out_head(model.out_ln(hL)), y)
+ head_opt.zero_grad(); loss_out.backward(); head_opt.step()
+
+ # Update blocks
+ for l in range(L):
+ a = credits[l]
+ rms = (a**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ f = model.blocks[l](hiddens[l].detach())
+ ll = (f * (a/rms)).sum(-1).mean()
+ block_opts[l].zero_grad(); ll.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ # Update embedding
+ a0 = credits[0]
+ rms0 = (a0**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ el = (model.embed(x) * (a0/rms0)).sum(-1).mean()
+ embed_opt.zero_grad(); el.backward(); embed_opt.step()
+
+ total_loss += loss_val.item()*batch; correct += (logits.argmax(1)==y).sum().item(); total += batch
+
+ for s in scheds: s.step()
+ test_acc = evaluate(model, test_loader, device)
+ log['train_loss'].append(total_loss/total); log['test_acc'].append(test_acc)
+ log['vloss'].append(total_vloss/total)
+ if epoch % 20 == 0 or epoch == 1:
+ phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}"
+ print(f" [vec_M{M}] Ep {epoch} ({phase}): loss={total_loss/total:.4f}, test={test_acc:.4f}")
+
+ return log, vector_net
+
+
+def compute_diagnostics(model, test_loader, device, method_name, value_net=None, vector_net=None, dfa_Bs=None):
+ model.eval()
+ if value_net: value_net.eval()
+ if vector_net: vector_net.eval()
+ L = model.num_blocks
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device); break
+ batch = x.size(0)
+
+ logits_bp, hiddens_bp = model(x, return_hidden=True)
+ for l in range(L+1): hiddens_bp[l].retain_grad()
+ F.cross_entropy(logits_bp, y).backward()
+ bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L+1)}
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1; s = e_T.detach()
+
+ results = {'bp_cosine': [], 'perturbation_rho': [], 'nudging': {'0.001':[], '0.003':[], '0.01':[]}}
+ for l in range(L):
+ h_l = hiddens[l].detach(); t_l = torch.full((batch,), l/L, device=device)
+ if method_name == 'dfa':
+ a_l = (s @ dfa_Bs[l].T).detach()
+ elif method_name.startswith('vec'):
+ a_l = vector_net(h_l, t_l, s).detach()
+ results['bp_cosine'].append(float(cosine_similarity_batch(a_l, bp_grads[l])))
+ def make_fwd(sl):
+ def f(h):
+ with torch.no_grad():
+ c=h
+ for i in range(sl,L): c=c+model.blocks[i](c)
+ return F.cross_entropy(model.out_head(model.out_ln(c)),y,reduction='none')
+ return f
+ fwd = make_fwd(l)
+ results['perturbation_rho'].append(float(perturbation_correlation(h_l, a_l, fwd, epsilon=1e-3, M=16)))
+ for eta in [0.001, 0.003, 0.01]:
+ results['nudging'][str(eta)].append(float(nudging_test(h_l, a_l, fwd, eta=eta)))
+ return results
+
+
+def run_config(L, d, method, seed, train_loader, test_loader, device,
+ epochs=100, lr=1e-3, lr_fb=1e-3, wd=0.01,
+ M=4, warmup_ratio=0.2, term_weight=1.0, eps=1e-3, beta=1.0):
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ model = ResidualMLP(32*32*3, d, 10, L).to(device)
+ config_str = f"L={L}, d={d}, {method}, s={seed}"
+ if method.startswith('vec'): config_str += f", wr={warmup_ratio}, tw={term_weight}"
+ print(f"\n --- {config_str} ---")
+
+ if method == 'dfa':
+ log, Bs = train_dfa(model, train_loader, test_loader, device, epochs, lr, wd)
+ diag = compute_diagnostics(model, test_loader, device, 'dfa', dfa_Bs=Bs)
+ elif method.startswith('vec'):
+ log, vnet = train_vector_online(model, train_loader, test_loader, device,
+ epochs, lr, lr_fb, wd, M=M,
+ warmup_ratio=warmup_ratio, term_weight=term_weight,
+ eps=eps, beta=beta)
+ diag = compute_diagnostics(model, test_loader, device, 'vec', vector_net=vnet)
+
+ result = {
+ 'method': method, 'L': L, 'd': d, 'seed': seed,
+ 'warmup_ratio': warmup_ratio, 'term_weight': term_weight, 'M': M,
+ 'test_acc': log['test_acc'][-1],
+ 'mean_gamma': float(np.mean(diag['bp_cosine'])),
+ 'mean_rho': float(np.mean(diag['perturbation_rho'])),
+ 'mean_nudge': float(np.mean(diag['nudging']['0.003'])),
+ 'per_layer_gamma': diag['bp_cosine'],
+ 'per_layer_rho': diag['perturbation_rho'],
+ }
+ print(f" Result: acc={result['test_acc']:.4f}, Gamma={result['mean_gamma']:.4f}, "
+ f"rho={result['mean_rho']:.4f}, nudge={result['mean_nudge']:.6f}")
+ return result
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Phase 5C: Online CIFAR Vector Pilot')
+ parser.add_argument('--L', type=int, default=4)
+ parser.add_argument('--d', type=int, default=256)
+ parser.add_argument('--epochs', type=int, default=100)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--wd', type=float, default=0.01)
+ parser.add_argument('--M', type=int, default=4)
+ parser.add_argument('--warmup_ratios', type=float, nargs='+', default=[0.0, 0.05, 0.2])
+ parser.add_argument('--term_weights', type=float, nargs='+', default=[1.0, 4.0])
+ parser.add_argument('--pert_eps', type=float, default=1e-3)
+ parser.add_argument('--pert_beta', type=float, default=1.0)
+ parser.add_argument('--seeds', type=int, nargs='+', default=[42])
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--gpu', type=int, default=2)
+ parser.add_argument('--output_dir', type=str, default='results/online_vec_pilot')
+ args = parser.parse_args()
+
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+ train_loader, test_loader = get_cifar10(args.batch_size)
+
+ all_results = []
+
+ for seed in args.seeds:
+ # DFA baseline
+ r = run_config(args.L, args.d, 'dfa', seed, train_loader, test_loader, device,
+ args.epochs, args.lr, args.lr_fb, args.wd)
+ all_results.append(r)
+
+ # Vector field sweep
+ for wr in args.warmup_ratios:
+ for tw in args.term_weights:
+ r = run_config(args.L, args.d, 'vec_eT_M4', seed, train_loader, test_loader, device,
+ args.epochs, args.lr, args.lr_fb, args.wd,
+ M=args.M, warmup_ratio=wr, term_weight=tw,
+ eps=args.pert_eps, beta=args.pert_beta)
+ all_results.append(r)
+
+ # Summary
+ dfa_baselines = {r['seed']: r for r in all_results if r['method'] == 'dfa'}
+ print(f"\n{'='*90}")
+ print("SUMMARY")
+ print(f"{'='*90}")
+ print(f"{'Method':<20} {'seed':>5} {'wr':>5} {'tw':>5} {'Acc':>6} {'Gamma':>7} {'rho':>7} {'nudge':>10} {'S1':>7} {'S2':>7}")
+ print("-" * 90)
+
+ positive = []
+ for r in all_results:
+ dfa = dfa_baselines.get(r['seed'], {})
+ S1 = r['mean_gamma'] - dfa.get('mean_gamma', 0)
+ S2 = r['mean_rho'] - dfa.get('mean_rho', 0)
+ wr_s = f"{r.get('warmup_ratio', '-'):>5.2f}" if r['method'] != 'dfa' else " -"
+ tw_s = f"{r.get('term_weight', '-'):>5.1f}" if r['method'] != 'dfa' else " -"
+ print(f"{r['method']:<20} {r['seed']:>5} {wr_s} {tw_s} {r['test_acc']:>6.4f} "
+ f"{r['mean_gamma']:>7.4f} {r['mean_rho']:>7.4f} {r['mean_nudge']:>10.6f} {S1:>7.4f} {S2:>7.4f}")
+ if r['method'] != 'dfa' and S1 > 0 and S2 > 0:
+ nb = r['mean_nudge'] < dfa.get('mean_nudge', 0)
+ positive.append({**r, 'S1': S1, 'S2': S2, 'nudge_better': nb})
+
+ if positive:
+ print(f"\nPOSITIVE CONFIGS (S1>0 AND S2>0):")
+ for p in positive:
+ print(f" {p['method']} wr={p['warmup_ratio']} tw={p['term_weight']}: "
+ f"S1={p['S1']:.4f} S2={p['S2']:.4f} nudge_better={p['nudge_better']}")
+ else:
+ print(f"\nNO POSITIVE CONFIGS.")
+
+ out_path = os.path.join(args.output_dir, f'pilot_s{args.seeds[0]}.json')
+ with open(out_path, 'w') as f:
+ json.dump(all_results, f, indent=2)
+ print(f"\nSaved to {out_path}")
+
+
+if __name__ == '__main__':
+ main()