summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
Diffstat (limited to 'experiments')
-rw-r--r--experiments/cifar_frozen_vector_credit.py648
-rw-r--r--experiments/cifar_online_vector_credit.py404
-rw-r--r--experiments/vector_credit_audit.py844
3 files changed, 1896 insertions, 0 deletions
diff --git a/experiments/cifar_frozen_vector_credit.py b/experiments/cifar_frozen_vector_credit.py
new file mode 100644
index 0000000..acd26e6
--- /dev/null
+++ b/experiments/cifar_frozen_vector_credit.py
@@ -0,0 +1,648 @@
+"""
+Phase 5B: Frozen CIFAR Vector Credit Transfer.
+
+Test whether direct vector credit field can recover better credit than scalar CB
+on frozen BP-trained CIFAR representations.
+
+Methods compared:
+- DFA (random)
+- StateBridge_eT
+- ScalarCB_eT
+- ScalarCB_deltaL
+- VectorField_eT_M{4,8,16}
+- VectorField_deltaL_M{4,8,16} (if resources allow)
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import copy
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+from metrics.credit_metrics import (
+ cosine_similarity_batch, perturbation_correlation, nudging_test
+)
+
+
+class VectorCreditNet(nn.Module):
+ """Direct vector credit field: a_phi(h_l, t_l, s) -> R^d."""
+ def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, d_hidden))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h, t, s):
+ h_normed = self.ln(h)
+ t_emb = self.time_embed(t)
+ inp = torch.cat([h_normed, t_emb, s], dim=-1)
+ return self.net(inp)
+
+
+def get_cifar10(batch_size=128):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+ train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+ test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+ return train_loader, test_loader
+
+
+def evaluate(model, test_loader, device):
+ model.eval()
+ correct, total = 0, 0
+ with torch.no_grad():
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ logits = model(x)
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+ return correct / total
+
+
+def train_bp_reference(model, train_loader, test_loader, device, epochs=100, lr=1e-3, wd=0.01):
+ optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+ for epoch in range(1, epochs + 1):
+ model.train()
+ total_loss, correct, total = 0, 0, 0
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ logits = model(x)
+ loss = F.cross_entropy(logits, y)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ total_loss += loss.item() * x.size(0)
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+ scheduler.step()
+ if epoch % 20 == 0 or epoch == 1:
+ test_acc = evaluate(model, test_loader, device)
+ print(f" [BP ref] Ep {epoch}: loss={total_loss/total:.4f}, test={test_acc:.4f}")
+ test_acc = evaluate(model, test_loader, device)
+ print(f" [BP ref] Final: {test_acc:.4f}")
+ return test_acc
+
+
+# =============================================================================
+# Estimator training functions (all on frozen model)
+# =============================================================================
+
+def train_state_bridge_frozen(model, train_loader, device, epochs, lr_fb):
+ d = model.d_hidden
+ L = model.num_blocks
+ state_pred = StateBridgeNet(d_hidden=d, s_dim=10, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+ state_opt = optim.Adam(state_pred.parameters(), lr=lr_fb)
+ model.eval()
+ for epoch in range(1, epochs + 1):
+ state_pred.train()
+ total_loss, n = 0, 0
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ hL_det = hiddens[-1].detach()
+ state_loss = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ pred_hL = state_pred(h_l_det, t_l, s)
+ target_norm = hL_det.norm(dim=-1, keepdim=True).clamp(min=1.0)
+ state_loss += (((pred_hL - hL_det) / target_norm) ** 2).sum(dim=-1).mean()
+ state_loss /= L
+ state_opt.zero_grad()
+ state_loss.backward()
+ state_opt.step()
+ total_loss += state_loss.item() * batch
+ n += batch
+ if epoch % 20 == 0 or epoch == 1:
+ print(f" [SB] Ep {epoch}: loss={total_loss/n:.6f}")
+ return state_pred
+
+
+def train_scalar_cb_frozen(model, train_loader, device, epochs, lr_fb, s_type='eT',
+ lam=0.1, K=4, sigma_bridge=0.05, ema_momentum=0.995,
+ term_grad_weight=1.0):
+ d = model.d_hidden
+ L = model.num_blocks
+ s_dim = 10 if s_type == 'eT' else d
+ value_net = ValueNet(d_hidden=d, s_dim=s_dim, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+ value_net_ema = create_ema_model(value_net)
+ value_opt = optim.Adam(value_net.parameters(), lr=lr_fb)
+ model.eval()
+ for epoch in range(1, epochs + 1):
+ value_net.train()
+ total_vloss, n = 0, 0
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ true_loss = F.cross_entropy(logits, y, reduction='none').detach()
+ hL_det = hiddens[-1].detach()
+ if s_type == 'eT':
+ s = e_T.detach()
+ else:
+ hL_req = hL_det.clone().requires_grad_(True)
+ logits_for_s = model.out_head(model.out_ln(hL_req))
+ ce_for_s = F.cross_entropy(logits_for_s, y, reduction='sum')
+ s = torch.autograd.grad(ce_for_s, hL_req, create_graph=False)[0].detach()
+
+ t_L = torch.ones(batch, device=device)
+ V_terminal = value_net(hL_det, t_L, s)
+ loss_term = ((V_terminal - true_loss) ** 2).mean()
+
+ loss_tgrad = torch.tensor(0.0, device=device)
+ if term_grad_weight > 0:
+ hL_req = hL_det.clone().requires_grad_(True)
+ V_at_L = value_net(hL_req, t_L, s)
+ grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
+ hL_req2 = hL_det.clone().requires_grad_(True)
+ logits_tgt = model.out_head(model.out_ln(hL_req2))
+ ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum')
+ a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach()
+ loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()
+
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ t_next = torch.full((batch,), (l + 1) / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+ with torch.no_grad():
+ h_next_det = hiddens[l + 1].detach()
+ log_terms = []
+ for k in range(K):
+ noise = sigma_bridge * torch.randn_like(h_next_det)
+ V_next = value_net_ema(h_next_det + noise, t_next, s)
+ log_terms.append(-V_next / lam)
+ log_stack = torch.stack(log_terms, dim=-1)
+ V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K))
+ loss_bridge += ((V_l - V_target.detach()) ** 2).mean()
+ loss_bridge /= L
+
+ vloss = loss_term + loss_bridge + term_grad_weight * loss_tgrad
+ value_opt.zero_grad()
+ vloss.backward()
+ torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
+ value_opt.step()
+ update_ema(value_net, value_net_ema, ema_momentum)
+ total_vloss += vloss.item() * batch
+ n += batch
+ if epoch % 20 == 0 or epoch == 1:
+ print(f" [CB_{s_type}] Ep {epoch}: vloss={total_vloss/n:.6f}")
+ return value_net
+
+
+def train_vector_field_frozen(model, train_loader, device, epochs, lr_fb,
+ s_type='eT', M=4, eps=1e-3, beta=1.0,
+ term_weight=1.0):
+ """
+ Train vector credit field on frozen CIFAR features.
+ Layer subsampling: each batch, randomly pick one layer for perturbation target.
+ Terminal matching always uses layer L.
+ """
+ d = model.d_hidden
+ L = model.num_blocks
+ s_dim = 10 if s_type == 'eT' else d
+
+ vector_net = VectorCreditNet(d_hidden=d, s_dim=s_dim, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+ vec_opt = optim.Adam(vector_net.parameters(), lr=lr_fb)
+ model.eval()
+
+ for epoch in range(1, epochs + 1):
+ vector_net.train()
+ total_vloss, n = 0, 0
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+
+ hL_det = hiddens[-1].detach()
+
+ # Compute s
+ if s_type == 'eT':
+ s = e_T.detach()
+ else:
+ hL_req = hL_det.clone().requires_grad_(True)
+ logits_for_s = model.out_head(model.out_ln(hL_req))
+ ce_for_s = F.cross_entropy(logits_for_s, y, reduction='sum')
+ s = torch.autograd.grad(ce_for_s, hL_req, create_graph=False)[0].detach()
+
+ # Terminal matching
+ loss_term = torch.tensor(0.0, device=device)
+ if term_weight > 0:
+ t_L = torch.ones(batch, device=device)
+ a_terminal = vector_net(hL_det, t_L, s)
+ hL_req = hL_det.clone().requires_grad_(True)
+ logits_tgt = model.out_head(model.out_ln(hL_req))
+ ce = F.cross_entropy(logits_tgt, y, reduction='sum')
+ delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()
+ loss_term = ((a_terminal - delta_L) ** 2).sum(dim=-1).mean()
+
+ # Perturbation target — subsample 1 random layer per batch
+ l = np.random.randint(0, L)
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ a_l = vector_net(h_l_det, t_l, s)
+
+ loss_proj = torch.tensor(0.0, device=device)
+ for _ in range(M):
+ v = torch.randn_like(h_l_det)
+ v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
+
+ with torch.no_grad():
+ # Use model.forward_from_layer for tail forward
+ logits_plus = model.forward_from_layer(h_l_det + eps * v, l)
+ loss_plus = F.cross_entropy(logits_plus, y, reduction='none')
+ logits_minus = model.forward_from_layer(h_l_det - eps * v, l)
+ loss_minus = F.cross_entropy(logits_minus, y, reduction='none')
+ g_j = (loss_plus - loss_minus) / (2 * eps)
+
+ pred_j = (a_l * v).sum(dim=-1)
+ loss_proj = loss_proj + ((pred_j - g_j.detach()) ** 2).mean()
+ loss_proj = loss_proj / M
+
+ vloss = term_weight * loss_term + beta * loss_proj
+ vec_opt.zero_grad()
+ vloss.backward()
+ torch.nn.utils.clip_grad_norm_(vector_net.parameters(), 1.0)
+ vec_opt.step()
+ total_vloss += vloss.item() * batch
+ n += batch
+
+ if epoch % 20 == 0 or epoch == 1:
+ print(f" [vec_{s_type}_M{M}] Ep {epoch}: vloss={total_vloss/n:.6f}")
+
+ return vector_net
+
+
+# =============================================================================
+# Evaluation
+# =============================================================================
+def evaluate_all(model, test_loader, device, estimators):
+ """Evaluate credit quality for all estimators on frozen features."""
+ model.eval()
+ d = model.d_hidden
+ L = model.num_blocks
+
+ # DFA baseline
+ dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
+
+ # Use multiple test batches for robust Gamma, single batch for rho/nudge (expensive)
+ results = {}
+ for name in list(estimators.keys()) + ['dfa']:
+ results[name] = {
+ 'bp_cosine': [[] for _ in range(L)],
+ 'perturbation_rho': [0.0] * L,
+ 'nudging_0.001': [0.0] * L,
+ 'nudging_0.003': [0.0] * L,
+ 'nudging_0.01': [0.0] * L,
+ }
+
+ n_batches = min(10, len(test_loader))
+ batch_idx = 0
+
+ for x, y in test_loader:
+ if batch_idx >= n_batches:
+ break
+ batch_idx += 1
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+
+ # BP gradients
+ for p in model.parameters():
+ p.requires_grad_(True)
+ model.zero_grad()
+ logits_bp, hiddens_bp = model(x, return_hidden=True)
+ for l in range(L + 1):
+ hiddens_bp[l].retain_grad()
+ loss_bp = F.cross_entropy(logits_bp, y)
+ loss_bp.backward()
+ bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)}
+ for p in model.parameters():
+ p.requires_grad_(False)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s_eT = e_T.detach()
+
+ hL_det = hiddens[-1].detach()
+ hL_req = hL_det.clone().requires_grad_(True)
+ logits_delta = model.out_head(model.out_ln(hL_req))
+ ce_delta = F.cross_entropy(logits_delta, y, reduction='sum')
+ delta_L = torch.autograd.grad(ce_delta, hL_req, create_graph=False)[0].detach()
+
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+
+ def make_fwd_fn(start_l):
+ def fwd_fn(h):
+ with torch.no_grad():
+ curr = h
+ for i in range(start_l, L):
+ curr = curr + model.blocks[i](curr)
+ out = model.out_head(model.out_ln(curr))
+ return F.cross_entropy(out, y, reduction='none')
+ return fwd_fn
+ fwd_fn = make_fwd_fn(l)
+
+ # DFA
+ a_dfa = (s_eT @ dfa_Bs[l].T).detach()
+ results['dfa']['bp_cosine'][l].append(cosine_similarity_batch(a_dfa, bp_grads[l]))
+ if batch_idx == 1:
+ results['dfa']['perturbation_rho'][l] = perturbation_correlation(h_l, a_dfa, fwd_fn, epsilon=1e-3, M=32)
+ for eta in [0.001, 0.003, 0.01]:
+ results['dfa'][f'nudging_{eta}'][l] = nudging_test(h_l, a_dfa, fwd_fn, eta=eta)
+
+ # Other estimators
+ for name, est in estimators.items():
+ if est['type'] == 'sb':
+ net = est['net']
+ net.eval()
+ h_l_req = h_l.clone().requires_grad_(True)
+ pred_hL = net(h_l_req, t_l, s_eT)
+ pred_logits = model.out_head(model.out_ln(pred_hL))
+ pred_loss = F.cross_entropy(pred_logits, y, reduction='sum')
+ a_l = torch.autograd.grad(pred_loss, h_l_req, create_graph=False)[0].detach()
+ elif est['type'] == 'cb':
+ net = est['net']
+ net.eval()
+ s = s_eT if est['s_type'] == 'eT' else delta_L
+ h_l_req = h_l.clone().requires_grad_(True)
+ V_l = net(h_l_req, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach()
+ elif est['type'] == 'vec':
+ net = est['net']
+ net.eval()
+ s = s_eT if est['s_type'] == 'eT' else delta_L
+ a_l = net(h_l, t_l, s).detach()
+
+ results[name]['bp_cosine'][l].append(cosine_similarity_batch(a_l, bp_grads[l]))
+ if batch_idx == 1:
+ results[name]['perturbation_rho'][l] = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=32)
+ for eta in [0.001, 0.003, 0.01]:
+ results[name][f'nudging_{eta}'][l] = nudging_test(h_l, a_l, fwd_fn, eta=eta)
+
+ # Average bp_cosine
+ for name in results:
+ for l in range(L):
+ vals = results[name]['bp_cosine'][l]
+ results[name]['bp_cosine'][l] = float(np.mean(vals)) if vals else 0.0
+ return results
+
+
+# =============================================================================
+# Main
+# =============================================================================
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+
+ train_loader, test_loader = get_cifar10(batch_size=args.batch_size)
+ input_dim = 32 * 32 * 3
+
+ # Step 1: Load/train BP reference
+ bp_ckpt = os.path.join(args.output_dir, f'bp_ref_L{args.num_blocks}_d{args.d_hidden}_s{args.seed}.pt')
+ model = ResidualMLP(input_dim, args.d_hidden, 10, args.num_blocks).to(device)
+
+ # Try loading from frozen_cifar directory first
+ alt_ckpt = f'results/frozen_cifar/bp_ref_L{args.num_blocks}_d{args.d_hidden}_s{args.seed}.pt'
+ if os.path.exists(alt_ckpt) and not args.retrain_bp:
+ print(f" Loading BP ref from {alt_ckpt}")
+ model.load_state_dict(torch.load(alt_ckpt, map_location=device))
+ bp_acc = evaluate(model, test_loader, device)
+ elif os.path.exists(bp_ckpt) and not args.retrain_bp:
+ print(f" Loading BP ref from {bp_ckpt}")
+ model.load_state_dict(torch.load(bp_ckpt, map_location=device))
+ bp_acc = evaluate(model, test_loader, device)
+ else:
+ bp_acc = train_bp_reference(model, train_loader, test_loader, device,
+ epochs=args.bp_epochs, lr=args.lr, wd=args.wd)
+ torch.save(model.state_dict(), bp_ckpt)
+ print(f" BP ref acc: {bp_acc:.4f}")
+
+ model.eval()
+ for p in model.parameters():
+ p.requires_grad_(False)
+
+ L = args.num_blocks
+ d = args.d_hidden
+
+ # Step 2: Train estimators
+ print(f"\n{'='*60}")
+ print(f"Training estimators (L={L}, d={d}, {args.estimator_epochs} epochs)")
+ print(f"{'='*60}")
+
+ estimators = {}
+
+ # StateBridge_eT
+ print("\n--- StateBridge_eT ---")
+ torch.manual_seed(args.seed + 1000)
+ sb = train_state_bridge_frozen(model, train_loader, device, args.estimator_epochs, args.lr_fb)
+ estimators['sb_eT'] = {'type': 'sb', 'net': sb, 's_type': 'eT'}
+
+ # ScalarCB_eT
+ print("\n--- ScalarCB_eT ---")
+ torch.manual_seed(args.seed + 2000)
+ cb_eT = train_scalar_cb_frozen(model, train_loader, device, args.estimator_epochs, args.lr_fb,
+ s_type='eT', term_grad_weight=args.term_grad_weight)
+ estimators['cb_eT'] = {'type': 'cb', 'net': cb_eT, 's_type': 'eT'}
+
+ # ScalarCB_deltaL
+ print("\n--- ScalarCB_deltaL ---")
+ torch.manual_seed(args.seed + 3000)
+ cb_dL = train_scalar_cb_frozen(model, train_loader, device, args.estimator_epochs, args.lr_fb,
+ s_type='deltaL', term_grad_weight=args.term_grad_weight)
+ estimators['cb_deltaL'] = {'type': 'cb', 'net': cb_dL, 's_type': 'deltaL'}
+
+ # Vector fields
+ for M in args.M_values:
+ for s_type in args.vec_s_types:
+ tag = f'vec_{s_type}_M{M}'
+ print(f"\n--- {tag} ---")
+ torch.manual_seed(args.seed + 4000 + M * 100 + (0 if s_type == 'eT' else 1))
+ vnet = train_vector_field_frozen(model, train_loader, device,
+ args.estimator_epochs, args.lr_fb,
+ s_type=s_type, M=M, eps=args.pert_eps,
+ beta=args.pert_beta, term_weight=args.term_weight_vec)
+ estimators[tag] = {'type': 'vec', 'net': vnet, 's_type': s_type}
+
+ # Step 3: Evaluate
+ print(f"\n{'='*60}")
+ print("Evaluating credit quality")
+ print(f"{'='*60}")
+ results = evaluate_all(model, test_loader, device, estimators)
+
+ # Print summary
+ all_methods = ['dfa', 'sb_eT', 'cb_eT', 'cb_deltaL'] + \
+ [f'vec_{st}_M{M}' for M in args.M_values for st in args.vec_s_types]
+ labels = {
+ 'dfa': 'DFA', 'sb_eT': 'StateBridge_eT',
+ 'cb_eT': 'ScalarCB_eT', 'cb_deltaL': 'ScalarCB_deltaL',
+ }
+ for M in args.M_values:
+ for st in args.vec_s_types:
+ labels[f'vec_{st}_M{M}'] = f'Vec_{st}_M{M}'
+
+ print(f"\n{'Method':<25} {'Gamma':>8} {'rho':>8} {'nudge':>10}")
+ print("-" * 55)
+
+ summary = {}
+ for m in all_methods:
+ if m not in results:
+ continue
+ r = results[m]
+ mg = np.mean(r['bp_cosine'])
+ mr = np.mean(r['perturbation_rho'])
+ mn = np.mean(r['nudging_0.003'])
+ summary[m] = {'mean_gamma': float(mg), 'mean_rho': float(mr), 'mean_nudge': float(mn)}
+ print(f"{labels.get(m, m):<25} {mg:>8.4f} {mr:>8.4f} {mn:>10.6f}")
+
+ # Per-layer detail
+ print(f"\n--- Per-layer Gamma ---")
+ for l in range(L):
+ row = f" L{l}: "
+ for m in all_methods:
+ if m in results:
+ row += f" {results[m]['bp_cosine'][l]:>8.4f}"
+ print(row)
+
+ print(f"\n--- Per-layer rho ---")
+ for l in range(L):
+ row = f" L{l}: "
+ for m in all_methods:
+ if m in results:
+ row += f" {results[m]['perturbation_rho'][l]:>8.4f}"
+ print(row)
+
+ # Save
+ save_data = {
+ 'config': {
+ 'num_blocks': L, 'd_hidden': d, 'seed': args.seed,
+ 'bp_acc': float(bp_acc), 'estimator_epochs': args.estimator_epochs,
+ },
+ 'summary': summary,
+ 'per_layer': {m: {
+ 'bp_cosine': results[m]['bp_cosine'],
+ 'perturbation_rho': results[m]['perturbation_rho'],
+ 'nudging_0.003': results[m]['nudging_0.003'],
+ } for m in all_methods if m in results},
+ }
+ out_path = os.path.join(args.output_dir,
+ f'frozen_vec_L{L}_d{d}_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(save_data, f, indent=2)
+ print(f"\nResults saved to {out_path}")
+
+ # Judgment
+ cb_eT_gamma = summary.get('cb_eT', {}).get('mean_gamma', 0)
+ cb_eT_rho = summary.get('cb_eT', {}).get('mean_rho', 0)
+ best_vec_gamma = max(summary.get(m, {}).get('mean_gamma', 0) for m in summary if m.startswith('vec_'))
+ best_vec_rho = max(summary.get(m, {}).get('mean_rho', 0) for m in summary if m.startswith('vec_'))
+ best_vec_name = max((m for m in summary if m.startswith('vec_')),
+ key=lambda m: summary[m]['mean_gamma'] + summary[m]['mean_rho'],
+ default='none')
+
+ print(f"\n{'='*60}")
+ print("JUDGMENT")
+ print(f"{'='*60}")
+ print(f"ScalarCB_eT: Gamma={cb_eT_gamma:.4f}, rho={cb_eT_rho:.4f}")
+ print(f"Best vector ({best_vec_name}): Gamma={best_vec_gamma:.4f}, rho={best_vec_rho:.4f}")
+
+ dg = best_vec_gamma - cb_eT_gamma
+ dr = best_vec_rho - cb_eT_rho
+ print(f"Delta: Gamma={dg:+.4f}, rho={dr:+.4f}")
+
+ if dg >= 0.05 and dr >= 0.05:
+ print("TRANSFER SUCCESS: Vector field significantly outperforms scalar CB on frozen CIFAR.")
+ elif dg > 0 and dr > 0:
+ print("MARGINAL: Vector field slightly better, but deltas below 0.05 threshold.")
+ else:
+ print("TRANSFER FAILED: Vector field does not outperform scalar CB on frozen CIFAR.")
+
+ return save_data
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Phase 5B: Frozen CIFAR Vector Transfer')
+ parser.add_argument('--num_blocks', type=int, default=4)
+ parser.add_argument('--d_hidden', type=int, default=256)
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--bp_epochs', type=int, default=100)
+ parser.add_argument('--estimator_epochs', type=int, default=100)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--wd', type=float, default=0.01)
+ parser.add_argument('--term_grad_weight', type=float, default=1.0)
+ parser.add_argument('--term_weight_vec', type=float, default=1.0)
+ parser.add_argument('--pert_eps', type=float, default=1e-3)
+ parser.add_argument('--pert_beta', type=float, default=1.0)
+ parser.add_argument('--M_values', type=int, nargs='+', default=[4, 8, 16])
+ parser.add_argument('--vec_s_types', type=str, nargs='+', default=['eT'])
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpu', type=int, default=2)
+ parser.add_argument('--output_dir', type=str, default='results/frozen_cifar_vec')
+ parser.add_argument('--retrain_bp', action='store_true')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/cifar_online_vector_credit.py b/experiments/cifar_online_vector_credit.py
new file mode 100644
index 0000000..3a3762c
--- /dev/null
+++ b/experiments/cifar_online_vector_credit.py
@@ -0,0 +1,404 @@
+"""
+Phase 5C: Online Shallow CIFAR Vector Credit Pilot.
+
+Minimal pilot: does vector field's frozen credit gain translate to online training?
+
+Compare DFA, ScalarCB_eT, VectorField_eT_M4 on CIFAR-10, L=4, d=256.
+Sweep warmup_ratio and term_weight.
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
+from metrics.credit_metrics import (
+ cosine_similarity_batch, perturbation_correlation, nudging_test
+)
+
+
+class VectorCreditNet(nn.Module):
+ def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, d_hidden))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h, t, s):
+ h_normed = self.ln(h)
+ t_emb = self.time_embed(t)
+ inp = torch.cat([h_normed, t_emb, s], dim=-1)
+ return self.net(inp)
+
+
+def get_cifar10(batch_size=128):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+ train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+ test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+ return train_loader, test_loader
+
+
+def evaluate(model, test_loader, device):
+ model.eval()
+ correct, total = 0, 0
+ with torch.no_grad():
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ logits = model(x)
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+ return correct / total
+
+
+def train_dfa(model, train_loader, test_loader, device, epochs, lr, wd):
+ d = model.d_hidden
+ L = model.num_blocks
+ Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr, weight_decay=wd)
+ scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+ log = {'train_loss': [], 'test_acc': []}
+ for epoch in range(1, epochs + 1):
+ model.train()
+ total_loss, correct, total = 0, 0, 0
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ hL = hiddens[-1].detach()
+ loss_out = F.cross_entropy(model.out_head(model.out_ln(hL)), y)
+ head_opt.zero_grad(); loss_out.backward(); head_opt.step()
+ for l in range(L):
+ a = (e_T @ Bs[l].T).detach()
+ rms = (a**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ f = model.blocks[l](hiddens[l].detach())
+ ll = (f * (a/rms)).sum(-1).mean()
+ block_opts[l].zero_grad(); ll.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+ a0 = (e_T @ Bs[0].T).detach()
+ rms0 = (a0**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ el = (model.embed(x) * (a0/rms0)).sum(-1).mean()
+ embed_opt.zero_grad(); el.backward(); embed_opt.step()
+ total_loss += loss_val.item() * batch; correct += (logits.argmax(1) == y).sum().item(); total += batch
+ for s in scheds: s.step()
+ test_acc = evaluate(model, test_loader, device)
+ log['train_loss'].append(total_loss/total); log['test_acc'].append(test_acc)
+ if epoch % 20 == 0 or epoch == 1:
+ print(f" [DFA] Ep {epoch}: loss={total_loss/total:.4f}, test={test_acc:.4f}")
+ return log, Bs
+
+
+def train_vector_online(model, train_loader, test_loader, device, epochs, lr, lr_fb, wd,
+ M=4, warmup_ratio=0.2, term_weight=1.0, eps=1e-3, beta=1.0):
+ d = model.d_hidden
+ L = model.num_blocks
+ warmup_epochs = max(1, int(epochs * warmup_ratio))
+
+ vector_net = VectorCreditNet(d_hidden=d, s_dim=10, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+ Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
+
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr, weight_decay=wd)
+ vec_opt = optim.Adam(vector_net.parameters(), lr=lr_fb)
+ scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+
+ log = {'train_loss': [], 'test_acc': [], 'vloss': []}
+
+ for epoch in range(1, epochs + 1):
+ model.train(); vector_net.train()
+ credit_blend = 0.0 if epoch <= warmup_epochs else min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs))
+ total_loss, correct, total, total_vloss = 0, 0, 0, 0
+
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ hL = hiddens[-1].detach()
+
+ # Train vector net: terminal matching
+ loss_term = torch.tensor(0.0, device=device)
+ if term_weight > 0:
+ t_L = torch.ones(batch, device=device)
+ a_term = vector_net(hL, t_L, s)
+ hL_req = hL.clone().requires_grad_(True)
+ logits_tgt = model.out_head(model.out_ln(hL_req))
+ ce = F.cross_entropy(logits_tgt, y, reduction='sum')
+ delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()
+ loss_term = ((a_term - delta_L) ** 2).sum(-1).mean()
+
+ # Perturbation target: subsample 1 layer
+ l_train = np.random.randint(0, L)
+ h_l = hiddens[l_train].detach()
+ t_l = torch.full((batch,), l_train / L, device=device)
+ a_l = vector_net(h_l, t_l, s)
+
+ loss_proj = torch.tensor(0.0, device=device)
+ for _ in range(M):
+ v = torch.randn_like(h_l)
+ v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
+ with torch.no_grad():
+ lp = F.cross_entropy(model.forward_from_layer(h_l + eps*v, l_train), y, reduction='none')
+ lm = F.cross_entropy(model.forward_from_layer(h_l - eps*v, l_train), y, reduction='none')
+ g_j = (lp - lm) / (2*eps)
+ loss_proj = loss_proj + (((a_l * v).sum(-1) - g_j.detach())**2).mean()
+ loss_proj = loss_proj / M
+
+ vloss = term_weight * loss_term + beta * loss_proj
+ vec_opt.zero_grad(); vloss.backward()
+ torch.nn.utils.clip_grad_norm_(vector_net.parameters(), 1.0)
+ vec_opt.step()
+ total_vloss += vloss.item() * batch
+
+ # Compute credits
+ with torch.no_grad():
+ vec_credits = [vector_net(hiddens[l].detach(),
+ torch.full((batch,), l/L, device=device), s).detach() for l in range(L)]
+ dfa_credits = [(e_T @ Bs[l].T).detach() for l in range(L)]
+
+ credits = []
+ for l in range(L):
+ if credit_blend >= 1.0:
+ credits.append(vec_credits[l])
+ elif credit_blend <= 0.0:
+ credits.append(dfa_credits[l])
+ else:
+ vr = (vec_credits[l]**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ dr = (dfa_credits[l]**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ credits.append(credit_blend * vec_credits[l]/vr + (1-credit_blend) * dfa_credits[l]/dr)
+
+ # Update head
+ loss_out = F.cross_entropy(model.out_head(model.out_ln(hL)), y)
+ head_opt.zero_grad(); loss_out.backward(); head_opt.step()
+
+ # Update blocks
+ for l in range(L):
+ a = credits[l]
+ rms = (a**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ f = model.blocks[l](hiddens[l].detach())
+ ll = (f * (a/rms)).sum(-1).mean()
+ block_opts[l].zero_grad(); ll.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ # Update embedding
+ a0 = credits[0]
+ rms0 = (a0**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ el = (model.embed(x) * (a0/rms0)).sum(-1).mean()
+ embed_opt.zero_grad(); el.backward(); embed_opt.step()
+
+ total_loss += loss_val.item()*batch; correct += (logits.argmax(1)==y).sum().item(); total += batch
+
+ for s in scheds: s.step()
+ test_acc = evaluate(model, test_loader, device)
+ log['train_loss'].append(total_loss/total); log['test_acc'].append(test_acc)
+ log['vloss'].append(total_vloss/total)
+ if epoch % 20 == 0 or epoch == 1:
+ phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}"
+ print(f" [vec_M{M}] Ep {epoch} ({phase}): loss={total_loss/total:.4f}, test={test_acc:.4f}")
+
+ return log, vector_net
+
+
+def compute_diagnostics(model, test_loader, device, method_name, value_net=None, vector_net=None, dfa_Bs=None):
+ model.eval()
+ if value_net: value_net.eval()
+ if vector_net: vector_net.eval()
+ L = model.num_blocks
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device); break
+ batch = x.size(0)
+
+ logits_bp, hiddens_bp = model(x, return_hidden=True)
+ for l in range(L+1): hiddens_bp[l].retain_grad()
+ F.cross_entropy(logits_bp, y).backward()
+ bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L+1)}
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1; s = e_T.detach()
+
+ results = {'bp_cosine': [], 'perturbation_rho': [], 'nudging': {'0.001':[], '0.003':[], '0.01':[]}}
+ for l in range(L):
+ h_l = hiddens[l].detach(); t_l = torch.full((batch,), l/L, device=device)
+ if method_name == 'dfa':
+ a_l = (s @ dfa_Bs[l].T).detach()
+ elif method_name.startswith('vec'):
+ a_l = vector_net(h_l, t_l, s).detach()
+ results['bp_cosine'].append(float(cosine_similarity_batch(a_l, bp_grads[l])))
+ def make_fwd(sl):
+ def f(h):
+ with torch.no_grad():
+ c=h
+ for i in range(sl,L): c=c+model.blocks[i](c)
+ return F.cross_entropy(model.out_head(model.out_ln(c)),y,reduction='none')
+ return f
+ fwd = make_fwd(l)
+ results['perturbation_rho'].append(float(perturbation_correlation(h_l, a_l, fwd, epsilon=1e-3, M=16)))
+ for eta in [0.001, 0.003, 0.01]:
+ results['nudging'][str(eta)].append(float(nudging_test(h_l, a_l, fwd, eta=eta)))
+ return results
+
+
+def run_config(L, d, method, seed, train_loader, test_loader, device,
+ epochs=100, lr=1e-3, lr_fb=1e-3, wd=0.01,
+ M=4, warmup_ratio=0.2, term_weight=1.0, eps=1e-3, beta=1.0):
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ model = ResidualMLP(32*32*3, d, 10, L).to(device)
+ config_str = f"L={L}, d={d}, {method}, s={seed}"
+ if method.startswith('vec'): config_str += f", wr={warmup_ratio}, tw={term_weight}"
+ print(f"\n --- {config_str} ---")
+
+ if method == 'dfa':
+ log, Bs = train_dfa(model, train_loader, test_loader, device, epochs, lr, wd)
+ diag = compute_diagnostics(model, test_loader, device, 'dfa', dfa_Bs=Bs)
+ elif method.startswith('vec'):
+ log, vnet = train_vector_online(model, train_loader, test_loader, device,
+ epochs, lr, lr_fb, wd, M=M,
+ warmup_ratio=warmup_ratio, term_weight=term_weight,
+ eps=eps, beta=beta)
+ diag = compute_diagnostics(model, test_loader, device, 'vec', vector_net=vnet)
+
+ result = {
+ 'method': method, 'L': L, 'd': d, 'seed': seed,
+ 'warmup_ratio': warmup_ratio, 'term_weight': term_weight, 'M': M,
+ 'test_acc': log['test_acc'][-1],
+ 'mean_gamma': float(np.mean(diag['bp_cosine'])),
+ 'mean_rho': float(np.mean(diag['perturbation_rho'])),
+ 'mean_nudge': float(np.mean(diag['nudging']['0.003'])),
+ 'per_layer_gamma': diag['bp_cosine'],
+ 'per_layer_rho': diag['perturbation_rho'],
+ }
+ print(f" Result: acc={result['test_acc']:.4f}, Gamma={result['mean_gamma']:.4f}, "
+ f"rho={result['mean_rho']:.4f}, nudge={result['mean_nudge']:.6f}")
+ return result
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Phase 5C: Online CIFAR Vector Pilot')
+ parser.add_argument('--L', type=int, default=4)
+ parser.add_argument('--d', type=int, default=256)
+ parser.add_argument('--epochs', type=int, default=100)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--wd', type=float, default=0.01)
+ parser.add_argument('--M', type=int, default=4)
+ parser.add_argument('--warmup_ratios', type=float, nargs='+', default=[0.0, 0.05, 0.2])
+ parser.add_argument('--term_weights', type=float, nargs='+', default=[1.0, 4.0])
+ parser.add_argument('--pert_eps', type=float, default=1e-3)
+ parser.add_argument('--pert_beta', type=float, default=1.0)
+ parser.add_argument('--seeds', type=int, nargs='+', default=[42])
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--gpu', type=int, default=2)
+ parser.add_argument('--output_dir', type=str, default='results/online_vec_pilot')
+ args = parser.parse_args()
+
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+ train_loader, test_loader = get_cifar10(args.batch_size)
+
+ all_results = []
+
+ for seed in args.seeds:
+ # DFA baseline
+ r = run_config(args.L, args.d, 'dfa', seed, train_loader, test_loader, device,
+ args.epochs, args.lr, args.lr_fb, args.wd)
+ all_results.append(r)
+
+ # Vector field sweep
+ for wr in args.warmup_ratios:
+ for tw in args.term_weights:
+ r = run_config(args.L, args.d, 'vec_eT_M4', seed, train_loader, test_loader, device,
+ args.epochs, args.lr, args.lr_fb, args.wd,
+ M=args.M, warmup_ratio=wr, term_weight=tw,
+ eps=args.pert_eps, beta=args.pert_beta)
+ all_results.append(r)
+
+ # Summary
+ dfa_baselines = {r['seed']: r for r in all_results if r['method'] == 'dfa'}
+ print(f"\n{'='*90}")
+ print("SUMMARY")
+ print(f"{'='*90}")
+ print(f"{'Method':<20} {'seed':>5} {'wr':>5} {'tw':>5} {'Acc':>6} {'Gamma':>7} {'rho':>7} {'nudge':>10} {'S1':>7} {'S2':>7}")
+ print("-" * 90)
+
+ positive = []
+ for r in all_results:
+ dfa = dfa_baselines.get(r['seed'], {})
+ S1 = r['mean_gamma'] - dfa.get('mean_gamma', 0)
+ S2 = r['mean_rho'] - dfa.get('mean_rho', 0)
+ wr_s = f"{r.get('warmup_ratio', '-'):>5.2f}" if r['method'] != 'dfa' else " -"
+ tw_s = f"{r.get('term_weight', '-'):>5.1f}" if r['method'] != 'dfa' else " -"
+ print(f"{r['method']:<20} {r['seed']:>5} {wr_s} {tw_s} {r['test_acc']:>6.4f} "
+ f"{r['mean_gamma']:>7.4f} {r['mean_rho']:>7.4f} {r['mean_nudge']:>10.6f} {S1:>7.4f} {S2:>7.4f}")
+ if r['method'] != 'dfa' and S1 > 0 and S2 > 0:
+ nb = r['mean_nudge'] < dfa.get('mean_nudge', 0)
+ positive.append({**r, 'S1': S1, 'S2': S2, 'nudge_better': nb})
+
+ if positive:
+ print(f"\nPOSITIVE CONFIGS (S1>0 AND S2>0):")
+ for p in positive:
+ print(f" {p['method']} wr={p['warmup_ratio']} tw={p['term_weight']}: "
+ f"S1={p['S1']:.4f} S2={p['S2']:.4f} nudge_better={p['nudge_better']}")
+ else:
+ print(f"\nNO POSITIVE CONFIGS.")
+
+ out_path = os.path.join(args.output_dir, f'pilot_s{args.seeds[0]}.json')
+ with open(out_path, 'w') as f:
+ json.dump(all_results, f, indent=2)
+ print(f"\nSaved to {out_path}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/vector_credit_audit.py b/experiments/vector_credit_audit.py
new file mode 100644
index 0000000..048efb7
--- /dev/null
+++ b/experiments/vector_credit_audit.py
@@ -0,0 +1,844 @@
+"""
+Phase 5A: Vector Credit Field Audit.
+
+Verify that the vector field's gains are real, not implementation artifacts.
+
+4 mandatory sanity checks:
+A. Train/eval direction split (independent random directions)
+B. Shuffled-target control (permute g_j within batch)
+C. No-terminal ablation (L_term = 0)
+D. One-sided vs symmetric finite difference
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import copy
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
+from metrics.credit_metrics import (
+ cosine_similarity_batch, perturbation_correlation, nudging_test
+)
+
+
+# =============================================================================
+# Synthetic teacher-student
+# =============================================================================
+class TeacherNet(nn.Module):
+ def __init__(self, d_hidden, num_classes, num_blocks, alpha=1.0, seed=0):
+ super().__init__()
+ self.d_hidden = d_hidden
+ self.num_blocks = num_blocks
+ self.alpha = alpha
+ rng = torch.Generator().manual_seed(seed)
+ self.Ws = nn.ParameterList()
+ for _ in range(num_blocks):
+ W = torch.randn(d_hidden, d_hidden, generator=rng) * 0.3 / (d_hidden ** 0.5)
+ U, S, Vh = torch.linalg.svd(W, full_matrices=False)
+ S_clamped = S.clamp(max=0.3)
+ W = U @ torch.diag(S_clamped) @ Vh
+ self.Ws.append(nn.Parameter(W, requires_grad=False))
+ self.U = nn.Parameter(
+ torch.randn(num_classes, d_hidden, generator=rng) / (d_hidden ** 0.5),
+ requires_grad=False)
+
+ def phi(self, z):
+ return (1 - self.alpha) * z + self.alpha * torch.tanh(z)
+
+ def forward(self, x):
+ h = x
+ for W in self.Ws:
+ h = h + self.phi(h @ W.T)
+ return h @ self.U.T
+
+
+class StudentBlock(nn.Module):
+ def __init__(self, d_hidden, alpha=1.0):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.w = nn.Linear(d_hidden, d_hidden, bias=False)
+ nn.init.normal_(self.w.weight, std=0.01)
+ self.alpha = alpha
+
+ def phi(self, z):
+ return (1 - self.alpha) * z + self.alpha * torch.tanh(z)
+
+ def forward(self, h):
+ return self.w(self.phi(self.ln(h)))
+
+
+class StudentNet(nn.Module):
+ def __init__(self, d_hidden, num_classes, num_blocks, alpha=1.0):
+ super().__init__()
+ self.blocks = nn.ModuleList([StudentBlock(d_hidden, alpha) for _ in range(num_blocks)])
+ self.out_head = nn.Linear(d_hidden, num_classes)
+ self.d_hidden = d_hidden
+ self.num_blocks = num_blocks
+
+ def forward(self, x, return_hidden=False):
+ h = x
+ hiddens = [h] if return_hidden else None
+ for block in self.blocks:
+ f = block(h)
+ h = h + f
+ if return_hidden:
+ hiddens.append(h)
+ logits = self.out_head(h)
+ if return_hidden:
+ return logits, hiddens
+ return logits
+
+ def forward_from_layer(self, h, start_layer):
+ for i in range(start_layer, self.num_blocks):
+ f = self.blocks[i](h)
+ h = h + f
+ return self.out_head(h)
+
+
+class VectorCreditNet(nn.Module):
+ """Direct vector credit field: a_phi(h_l, t_l, s) -> R^d."""
+ def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, d_hidden))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h, t, s):
+ h_normed = self.ln(h)
+ t_emb = self.time_embed(t)
+ inp = torch.cat([h_normed, t_emb, s], dim=-1)
+ return self.net(inp)
+
+
+def generate_batch(teacher, d_hidden, num_classes, batch_size, device):
+ x = torch.randn(batch_size, d_hidden, device=device)
+ with torch.no_grad():
+ teacher_logits = teacher(x)
+ y = teacher_logits.argmax(dim=-1)
+ return x, y
+
+
+# =============================================================================
+# Training: vector field with audit controls
+# =============================================================================
+def train_vector_field_audit(model, teacher, device, args, M=4,
+ use_terminal=True,
+ shuffle_targets=False,
+ use_central_diff=True,
+ tag='vec'):
+ """
+ Train vector credit field with configurable audit controls.
+
+ Args:
+ use_terminal: if False, L_term = 0 (no-terminal ablation)
+ shuffle_targets: if True, permute g_j within batch (leak check)
+ use_central_diff: if True, central difference; if False, one-sided
+ tag: label for printing
+ """
+ d = model.d_hidden
+ L = model.num_blocks
+ num_classes = args.num_classes
+
+ vector_net = VectorCreditNet(d_hidden=d, s_dim=num_classes, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+
+ Bs = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)]
+
+ block_opts = [optim.AdamW(b.parameters(), lr=args.lr, weight_decay=0.01) for b in model.blocks]
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=args.lr, weight_decay=0.01)
+ vec_opt = optim.Adam(vector_net.parameters(), lr=args.lr_fb)
+
+ warmup_epochs = max(1, int(args.epochs * args.warmup_ratio))
+ eps = args.pert_eps
+ beta = args.pert_beta
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ vector_net.train()
+
+ if epoch <= warmup_epochs:
+ credit_blend = 0.0
+ else:
+ credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs))
+
+ total_loss, correct, total = 0, 0, 0
+ total_vloss = 0
+
+ for _ in range(args.steps_per_epoch):
+ x, y = generate_batch(teacher, d, num_classes, args.batch_size, device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ hL_det = hiddens[-1].detach()
+
+ # --- Terminal matching ---
+ loss_term = torch.tensor(0.0, device=device)
+ if use_terminal:
+ t_L = torch.ones(batch, device=device)
+ a_terminal = vector_net(hL_det, t_L, s)
+ hL_req = hL_det.clone().requires_grad_(True)
+ logits_tgt = model.out_head(hL_req)
+ ce = F.cross_entropy(logits_tgt, y, reduction='sum')
+ delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()
+ loss_term = ((a_terminal - delta_L) ** 2).sum(dim=-1).mean()
+
+ # --- Perturbation directional targets ---
+ # IMPORTANT: training directions are sampled fresh each step.
+ # Evaluation uses independently sampled directions (see compute_diagnostics).
+ loss_proj = torch.tensor(0.0, device=device)
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ a_l = vector_net(h_l_det, t_l, s)
+
+ layer_proj_loss = 0.0
+ for _ in range(M):
+ v = torch.randn_like(h_l_det)
+ v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
+
+ with torch.no_grad():
+ if use_central_diff:
+ # Central difference: [loss(h+eps*v) - loss(h-eps*v)] / (2*eps)
+ logits_plus = model.forward_from_layer(h_l_det + eps * v, l)
+ loss_plus = F.cross_entropy(logits_plus, y, reduction='none')
+ logits_minus = model.forward_from_layer(h_l_det - eps * v, l)
+ loss_minus = F.cross_entropy(logits_minus, y, reduction='none')
+ g_j = (loss_plus - loss_minus) / (2 * eps)
+ else:
+ # One-sided difference: [loss(h+eps*v) - loss(h)] / eps
+ logits_base = model.forward_from_layer(h_l_det, l)
+ loss_base = F.cross_entropy(logits_base, y, reduction='none')
+ logits_plus = model.forward_from_layer(h_l_det + eps * v, l)
+ loss_plus = F.cross_entropy(logits_plus, y, reduction='none')
+ g_j = (loss_plus - loss_base) / eps
+
+ # Shuffled-target control: permute g_j within batch
+ if shuffle_targets:
+ perm = torch.randperm(batch, device=device)
+ g_j = g_j[perm]
+
+ pred_j = (a_l * v).sum(dim=-1)
+ layer_proj_loss = layer_proj_loss + ((pred_j - g_j.detach()) ** 2).mean()
+
+ loss_proj = loss_proj + layer_proj_loss / M
+ loss_proj = loss_proj / L
+
+ vec_loss = loss_term + beta * loss_proj
+ vec_opt.zero_grad()
+ vec_loss.backward()
+ torch.nn.utils.clip_grad_norm_(vector_net.parameters(), 1.0)
+ vec_opt.step()
+ total_vloss += vec_loss.item() * batch
+
+ # --- Block updates ---
+ with torch.no_grad():
+ vec_credits = []
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ a_l = vector_net(h_l_det, t_l, s)
+ vec_credits.append(a_l.detach())
+
+ dfa_credits = [(e_T @ Bs[l].T).detach() for l in range(L)]
+
+ credits = []
+ for l in range(L):
+ if credit_blend >= 1.0:
+ credits.append(vec_credits[l])
+ elif credit_blend <= 0.0:
+ credits.append(dfa_credits[l])
+ else:
+ vc_rms = (vec_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ credits.append(credit_blend * vec_credits[l] / vc_rms +
+ (1 - credit_blend) * dfa_credits[l] / dfa_rms)
+
+ logits_out = model.out_head(hL_det)
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * (a / rms)).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ if epoch % 20 == 0 or epoch == 1:
+ acc = correct / total
+ print(f" [{tag}] Ep {epoch}: loss={total_loss/total:.4f}, acc={acc:.4f}, "
+ f"vloss={total_vloss/total:.6f}")
+
+ return vector_net
+
+
+def train_scalar_cb(model, teacher, device, args):
+ """Scalar credit bridge baseline."""
+ d = model.d_hidden
+ L = model.num_blocks
+ num_classes = args.num_classes
+
+ value_net = ValueNet(d_hidden=d, s_dim=num_classes, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+ value_net_ema = create_ema_model(value_net)
+
+ Bs = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)]
+
+ block_opts = [optim.AdamW(b.parameters(), lr=args.lr, weight_decay=0.01) for b in model.blocks]
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=args.lr, weight_decay=0.01)
+ value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb)
+
+ warmup_epochs = max(1, int(args.epochs * args.warmup_ratio))
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ value_net.train()
+
+ if epoch <= warmup_epochs:
+ credit_blend = 0.0
+ else:
+ credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs))
+
+ total_loss, correct, total = 0, 0, 0
+ for _ in range(args.steps_per_epoch):
+ x, y = generate_batch(teacher, d, num_classes, args.batch_size, device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ true_loss = F.cross_entropy(logits, y, reduction='none').detach()
+
+ hL_det = hiddens[-1].detach()
+ t_L = torch.ones(batch, device=device)
+ V_term = value_net(hL_det, t_L, s)
+ loss_term = ((V_term - true_loss) ** 2).mean()
+
+ hL_req = hL_det.clone().requires_grad_(True)
+ V_at_L = value_net(hL_req, t_L, s)
+ grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
+ hL_req2 = hL_det.clone().requires_grad_(True)
+ logits_tgt = model.out_head(hL_req2)
+ ce = F.cross_entropy(logits_tgt, y, reduction='sum')
+ a_L_exact = torch.autograd.grad(ce, hL_req2, create_graph=False)[0].detach()
+ loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()
+
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ t_next = torch.full((batch,), (l + 1) / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+ with torch.no_grad():
+ h_next = hiddens[l + 1].detach()
+ log_terms = []
+ for k in range(args.K):
+ noise = args.sigma_bridge * torch.randn_like(h_next)
+ V_next = value_net_ema(h_next + noise, t_next, s)
+ log_terms.append(-V_next / args.lam)
+ log_stack = torch.stack(log_terms, dim=-1)
+ V_target = -args.lam * (torch.logsumexp(log_stack, dim=-1) - np.log(args.K))
+ loss_bridge += ((V_l - V_target.detach()) ** 2).mean()
+ loss_bridge /= L
+
+ vloss = loss_term + loss_bridge + args.term_grad_weight * loss_tgrad
+ value_opt.zero_grad()
+ vloss.backward()
+ torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
+ value_opt.step()
+ update_ema(value_net, value_net_ema, args.ema_momentum)
+
+ cb_credits = []
+ for l in range(L):
+ h_l_det = hiddens[l].detach().requires_grad_(True)
+ t_l = torch.full((batch,), l / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0]
+ cb_credits.append(a_l.detach())
+
+ dfa_credits = [(e_T @ Bs[l].T).detach() for l in range(L)]
+
+ credits = []
+ for l in range(L):
+ if credit_blend >= 1.0:
+ credits.append(cb_credits[l])
+ elif credit_blend <= 0.0:
+ credits.append(dfa_credits[l])
+ else:
+ cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ credits.append(credit_blend * cb_credits[l] / cb_rms +
+ (1 - credit_blend) * dfa_credits[l] / dfa_rms)
+
+ logits_out = model.out_head(hL_det)
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * (a / rms)).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ if epoch % 20 == 0 or epoch == 1:
+ print(f" [scalar_cb] Ep {epoch}: loss={total_loss/total:.4f}, acc={correct/total:.4f}")
+
+ return value_net
+
+
+def train_dfa(model, teacher, device, args):
+ """DFA baseline."""
+ d = model.d_hidden
+ L = model.num_blocks
+ num_classes = args.num_classes
+ Bs = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)]
+
+ block_opts = [optim.AdamW(b.parameters(), lr=args.lr, weight_decay=0.01) for b in model.blocks]
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=args.lr, weight_decay=0.01)
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ total_loss, correct, total = 0, 0, 0
+ for _ in range(args.steps_per_epoch):
+ x, y = generate_batch(teacher, d, num_classes, args.batch_size, device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ hL_det = hiddens[-1].detach()
+ logits_out = model.out_head(hL_det)
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = (e_T @ Bs[l].T).detach()
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * (a / rms)).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+ if epoch % 20 == 0 or epoch == 1:
+ print(f" [DFA] Ep {epoch}: loss={total_loss/total:.4f}, acc={correct/total:.4f}")
+ return Bs
+
+
+# =============================================================================
+# Diagnostics — uses INDEPENDENT eval directions (check A)
+# =============================================================================
+def compute_diagnostics(model, teacher, device, method_name, args,
+ value_net=None, vector_net=None, dfa_Bs=None):
+ """
+ Compute Gamma, rho, nudging per layer.
+ IMPORTANT: perturbation_correlation uses its own freshly-sampled directions,
+ completely independent of any training directions. This ensures check A.
+ """
+ model.eval()
+ if value_net is not None:
+ value_net.eval()
+ if vector_net is not None:
+ vector_net.eval()
+
+ d = model.d_hidden
+ L = model.num_blocks
+ num_classes = args.num_classes
+
+ # Use a fixed eval seed different from training
+ eval_rng = torch.Generator(device=device)
+ eval_rng.manual_seed(99999)
+
+ x = torch.randn(512, d, device=device, generator=eval_rng)
+ with torch.no_grad():
+ teacher_logits = teacher(x)
+ y = teacher_logits.argmax(dim=-1)
+ batch = x.size(0)
+
+ # BP gradients (evaluation only — never used for training)
+ h = x.detach().requires_grad_(True)
+ hiddens_bp = [h]
+ for block in model.blocks:
+ f = block(hiddens_bp[-1])
+ h_next = hiddens_bp[-1] + f
+ hiddens_bp.append(h_next)
+ logits_bp = model.out_head(hiddens_bp[-1])
+ loss_bp = F.cross_entropy(logits_bp, y)
+ grads = torch.autograd.grad(loss_bp, hiddens_bp, retain_graph=False)
+ bp_grads = {l: grads[l].detach().clone() for l in range(L + 1)}
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ results = {'bp_cosine': [], 'perturbation_rho': [], 'nudging': []}
+
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+
+ if method_name == 'dfa':
+ a_l = (s @ dfa_Bs[l].T).detach()
+ elif method_name == 'scalar_cb':
+ h_l_req = h_l.clone().requires_grad_(True)
+ V_l = value_net(h_l_req, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach()
+ elif method_name.startswith('vec'):
+ a_l = vector_net(h_l, t_l, s).detach()
+ else:
+ raise ValueError(f"Unknown: {method_name}")
+
+ bp_cos = cosine_similarity_batch(a_l, bp_grads[l])
+ results['bp_cosine'].append(float(bp_cos))
+
+ # perturbation_correlation uses its own random directions internally
+ # (from metrics/credit_metrics.py — independent of training directions)
+ def make_fwd_fn(start_l):
+ def fwd_fn(h):
+ with torch.no_grad():
+ logits = model.forward_from_layer(h, start_l)
+ return F.cross_entropy(logits, y, reduction='none')
+ return fwd_fn
+
+ fwd_fn = make_fwd_fn(l)
+ rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=32)
+ results['perturbation_rho'].append(float(rho))
+
+ nud = nudging_test(h_l, a_l, fwd_fn, eta=0.003)
+ results['nudging'].append(float(nud))
+
+ return results
+
+
+# =============================================================================
+# Main
+# =============================================================================
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ all_results = []
+
+ for L in args.depths:
+ for seed in args.seeds:
+ print(f"\n{'='*60}")
+ print(f"L={L}, seed={seed}")
+ print(f"{'='*60}")
+
+ teacher = TeacherNet(args.d_hidden, args.num_classes, L,
+ alpha=args.alpha, seed=seed * 1000).to(device)
+
+ # --- DFA ---
+ print("\n --- DFA ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_dfa = StudentNet(args.d_hidden, args.num_classes, L, alpha=args.alpha).to(device)
+ Bs = train_dfa(model_dfa, teacher, device, args)
+ diag = compute_diagnostics(model_dfa, teacher, device, 'dfa', args, dfa_Bs=Bs)
+ r = {'method': 'dfa', 'L': L, 'seed': seed,
+ 'mean_gamma': float(np.mean(diag['bp_cosine'])),
+ 'mean_rho': float(np.mean(diag['perturbation_rho'])),
+ 'mean_nudge': float(np.mean(diag['nudging'])),
+ 'per_layer': diag}
+ print(f" Result: Gamma={r['mean_gamma']:.4f}, rho={r['mean_rho']:.4f}, nudge={r['mean_nudge']:.6f}")
+ all_results.append(r)
+
+ # --- Scalar CB ---
+ print("\n --- Scalar CB ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_cb = StudentNet(args.d_hidden, args.num_classes, L, alpha=args.alpha).to(device)
+ vnet = train_scalar_cb(model_cb, teacher, device, args)
+ diag = compute_diagnostics(model_cb, teacher, device, 'scalar_cb', args, value_net=vnet)
+ r = {'method': 'scalar_cb', 'L': L, 'seed': seed,
+ 'mean_gamma': float(np.mean(diag['bp_cosine'])),
+ 'mean_rho': float(np.mean(diag['perturbation_rho'])),
+ 'mean_nudge': float(np.mean(diag['nudging'])),
+ 'per_layer': diag}
+ print(f" Result: Gamma={r['mean_gamma']:.4f}, rho={r['mean_rho']:.4f}, nudge={r['mean_nudge']:.6f}")
+ all_results.append(r)
+
+ # --- Vector Field M4 (central diff, with terminal) ---
+ print("\n --- vec_eT_M4 (central, +term) ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_v4 = StudentNet(args.d_hidden, args.num_classes, L, alpha=args.alpha).to(device)
+ vnet4 = train_vector_field_audit(model_v4, teacher, device, args, M=4,
+ use_terminal=True, shuffle_targets=False,
+ use_central_diff=True, tag='vec_eT_M4')
+ diag = compute_diagnostics(model_v4, teacher, device, 'vec_eT_M4', args, vector_net=vnet4)
+ r = {'method': 'vec_eT_M4', 'L': L, 'seed': seed,
+ 'mean_gamma': float(np.mean(diag['bp_cosine'])),
+ 'mean_rho': float(np.mean(diag['perturbation_rho'])),
+ 'mean_nudge': float(np.mean(diag['nudging'])),
+ 'per_layer': diag}
+ print(f" Result: Gamma={r['mean_gamma']:.4f}, rho={r['mean_rho']:.4f}, nudge={r['mean_nudge']:.6f}")
+ all_results.append(r)
+
+ # --- Vector Field M8 (central diff, with terminal) ---
+ if 8 in args.M_values:
+ print("\n --- vec_eT_M8 (central, +term) ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_v8 = StudentNet(args.d_hidden, args.num_classes, L, alpha=args.alpha).to(device)
+ vnet8 = train_vector_field_audit(model_v8, teacher, device, args, M=8,
+ use_terminal=True, shuffle_targets=False,
+ use_central_diff=True, tag='vec_eT_M8')
+ diag = compute_diagnostics(model_v8, teacher, device, 'vec_eT_M8', args, vector_net=vnet8)
+ r = {'method': 'vec_eT_M8', 'L': L, 'seed': seed,
+ 'mean_gamma': float(np.mean(diag['bp_cosine'])),
+ 'mean_rho': float(np.mean(diag['perturbation_rho'])),
+ 'mean_nudge': float(np.mean(diag['nudging'])),
+ 'per_layer': diag}
+ print(f" Result: Gamma={r['mean_gamma']:.4f}, rho={r['mean_rho']:.4f}, nudge={r['mean_nudge']:.6f}")
+ all_results.append(r)
+
+ # =================================================================
+ # SANITY CHECKS (only for first seed to save time, unless full mode)
+ # =================================================================
+ if seed == args.seeds[0] or args.full_audit:
+ # --- Check B: Shuffled-target control ---
+ print("\n --- vec_eT_M4_shuffleCtrl ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_shuf = StudentNet(args.d_hidden, args.num_classes, L, alpha=args.alpha).to(device)
+ vnet_shuf = train_vector_field_audit(model_shuf, teacher, device, args, M=4,
+ use_terminal=True, shuffle_targets=True,
+ use_central_diff=True, tag='vec_shuffleCtrl')
+ diag = compute_diagnostics(model_shuf, teacher, device, 'vec_shuffleCtrl', args, vector_net=vnet_shuf)
+ r = {'method': 'vec_eT_M4_shuffleCtrl', 'L': L, 'seed': seed,
+ 'mean_gamma': float(np.mean(diag['bp_cosine'])),
+ 'mean_rho': float(np.mean(diag['perturbation_rho'])),
+ 'mean_nudge': float(np.mean(diag['nudging'])),
+ 'per_layer': diag}
+ print(f" Result: Gamma={r['mean_gamma']:.4f}, rho={r['mean_rho']:.4f}, nudge={r['mean_nudge']:.6f}")
+ all_results.append(r)
+
+ # --- Check C: No-terminal ablation ---
+ print("\n --- vec_eT_M4_noTerm ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_nt = StudentNet(args.d_hidden, args.num_classes, L, alpha=args.alpha).to(device)
+ vnet_nt = train_vector_field_audit(model_nt, teacher, device, args, M=4,
+ use_terminal=False, shuffle_targets=False,
+ use_central_diff=True, tag='vec_noTerm')
+ diag = compute_diagnostics(model_nt, teacher, device, 'vec_noTerm', args, vector_net=vnet_nt)
+ r = {'method': 'vec_eT_M4_noTerm', 'L': L, 'seed': seed,
+ 'mean_gamma': float(np.mean(diag['bp_cosine'])),
+ 'mean_rho': float(np.mean(diag['perturbation_rho'])),
+ 'mean_nudge': float(np.mean(diag['nudging'])),
+ 'per_layer': diag}
+ print(f" Result: Gamma={r['mean_gamma']:.4f}, rho={r['mean_rho']:.4f}, nudge={r['mean_nudge']:.6f}")
+ all_results.append(r)
+
+ # --- Check D: One-sided difference ---
+ print("\n --- vec_eT_M4_onesided ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_os = StudentNet(args.d_hidden, args.num_classes, L, alpha=args.alpha).to(device)
+ vnet_os = train_vector_field_audit(model_os, teacher, device, args, M=4,
+ use_terminal=True, shuffle_targets=False,
+ use_central_diff=False, tag='vec_onesided')
+ diag = compute_diagnostics(model_os, teacher, device, 'vec_onesided', args, vector_net=vnet_os)
+ r = {'method': 'vec_eT_M4_onesided', 'L': L, 'seed': seed,
+ 'mean_gamma': float(np.mean(diag['bp_cosine'])),
+ 'mean_rho': float(np.mean(diag['perturbation_rho'])),
+ 'mean_nudge': float(np.mean(diag['nudging'])),
+ 'per_layer': diag}
+ print(f" Result: Gamma={r['mean_gamma']:.4f}, rho={r['mean_rho']:.4f}, nudge={r['mean_nudge']:.6f}")
+ all_results.append(r)
+
+ # =================================================================
+ # Summary
+ # =================================================================
+ print(f"\n{'='*80}")
+ print("AUDIT SUMMARY")
+ print(f"{'='*80}")
+ print(f"{'Method':<30} {'L':>3} {'seed':>5} {'Gamma':>8} {'rho':>8} {'nudge':>10}")
+ print("-" * 70)
+ for r in all_results:
+ print(f"{r['method']:<30} {r['L']:>3} {r['seed']:>5} "
+ f"{r['mean_gamma']:>8.4f} {r['mean_rho']:>8.4f} {r['mean_nudge']:>10.6f}")
+
+ # Check verdicts
+ print(f"\n{'='*60}")
+ print("SANITY CHECK VERDICTS")
+ print(f"{'='*60}")
+
+ for L in args.depths:
+ seed0 = args.seeds[0]
+ vec_main = [r for r in all_results if r['method'] == 'vec_eT_M4' and r['L'] == L and r['seed'] == seed0]
+ scalar_cb = [r for r in all_results if r['method'] == 'scalar_cb' and r['L'] == L and r['seed'] == seed0]
+ shuf = [r for r in all_results if r['method'] == 'vec_eT_M4_shuffleCtrl' and r['L'] == L and r['seed'] == seed0]
+ noterm = [r for r in all_results if r['method'] == 'vec_eT_M4_noTerm' and r['L'] == L and r['seed'] == seed0]
+ onesided = [r for r in all_results if r['method'] == 'vec_eT_M4_onesided' and r['L'] == L and r['seed'] == seed0]
+
+ if not vec_main or not scalar_cb:
+ continue
+ v = vec_main[0]
+ cb = scalar_cb[0]
+
+ print(f"\n L={L}:")
+ delta_gamma = v['mean_gamma'] - cb['mean_gamma']
+ delta_rho = v['mean_rho'] - cb['mean_rho']
+ print(f" vec_M4 vs scalar_cb: delta_Gamma={delta_gamma:+.4f}, delta_rho={delta_rho:+.4f}")
+
+ if shuf:
+ s = shuf[0]
+ print(f" Check B (shuffle): Gamma={s['mean_gamma']:.4f}, rho={s['mean_rho']:.4f}")
+ if s['mean_gamma'] < v['mean_gamma'] * 0.5 and s['mean_rho'] < v['mean_rho'] * 0.5:
+ print(f" -> PASS: shuffled control collapses (Gamma dropped by {v['mean_gamma']-s['mean_gamma']:.3f})")
+ else:
+ print(f" -> FAIL: shuffled control too close to main result!")
+
+ if noterm:
+ n = noterm[0]
+ print(f" Check C (noTerm): Gamma={n['mean_gamma']:.4f}, rho={n['mean_rho']:.4f}")
+ if n['mean_gamma'] < v['mean_gamma'] * 0.8:
+ print(f" -> PASS: terminal matching contributes (Gamma dropped by {v['mean_gamma']-n['mean_gamma']:.3f})")
+ else:
+ print(f" -> NOTE: terminal removal didn't collapse result. Perturbation target alone is sufficient.")
+
+ if onesided:
+ o = onesided[0]
+ print(f" Check D (onesided): Gamma={o['mean_gamma']:.4f}, rho={o['mean_rho']:.4f}")
+ if abs(o['mean_gamma'] - v['mean_gamma']) < 0.15:
+ print(f" -> PASS: one-sided ≈ central (difference = {abs(o['mean_gamma']-v['mean_gamma']):.3f})")
+ else:
+ print(f" -> NOTE: one-sided differs from central by {abs(o['mean_gamma']-v['mean_gamma']):.3f}")
+
+ # Final verdict
+ print(f"\n{'='*60}")
+ print("OVERALL AUDIT VERDICT")
+ print(f"{'='*60}")
+ all_pass = True
+ for L in args.depths:
+ for seed in args.seeds:
+ v = [r for r in all_results if r['method'] == 'vec_eT_M4' and r['L'] == L and r['seed'] == seed]
+ cb = [r for r in all_results if r['method'] == 'scalar_cb' and r['L'] == L and r['seed'] == seed]
+ if v and cb:
+ dg = v[0]['mean_gamma'] - cb[0]['mean_gamma']
+ dr = v[0]['mean_rho'] - cb[0]['mean_rho']
+ if dg < 0.2 or dr < 0.2:
+ print(f" L={L} seed={seed}: delta_Gamma={dg:.3f}, delta_rho={dr:.3f} - BELOW THRESHOLD")
+ all_pass = False
+ else:
+ print(f" L={L} seed={seed}: delta_Gamma={dg:.3f}, delta_rho={dr:.3f} - PASS")
+
+ shuf_results = [r for r in all_results if 'shuffleCtrl' in r['method']]
+ for s in shuf_results:
+ if s['mean_rho'] > 0.3:
+ print(f" SHUFFLE CONTROL WARNING: L={s['L']} rho={s['mean_rho']:.3f} too high!")
+ all_pass = False
+
+ if all_pass:
+ print("\n AUDIT PASSED. Vector field gains are real.")
+ else:
+ print("\n AUDIT FAILED or INCOMPLETE. Investigate before proceeding.")
+
+ # Save
+ save_data = []
+ for r in all_results:
+ save_r = {k: v for k, v in r.items() if k != 'per_layer'}
+ save_r['per_layer_gamma'] = r['per_layer']['bp_cosine']
+ save_r['per_layer_rho'] = r['per_layer']['perturbation_rho']
+ save_r['per_layer_nudge'] = r['per_layer']['nudging']
+ save_data.append(save_r)
+
+ out_path = os.path.join(args.output_dir, 'audit_results.json')
+ with open(out_path, 'w') as f:
+ json.dump(save_data, f, indent=2)
+ print(f"\nResults saved to {out_path}")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Phase 5A: Vector Credit Field Audit')
+ parser.add_argument('--d_hidden', type=int, default=128)
+ parser.add_argument('--num_classes', type=int, default=10)
+ parser.add_argument('--alpha', type=float, default=1.0)
+ parser.add_argument('--depths', type=int, nargs='+', default=[4])
+ parser.add_argument('--M_values', type=int, nargs='+', default=[4, 8])
+ parser.add_argument('--epochs', type=int, default=80)
+ parser.add_argument('--steps_per_epoch', type=int, default=50)
+ parser.add_argument('--batch_size', type=int, default=256)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--warmup_ratio', type=float, default=0.05)
+ parser.add_argument('--term_grad_weight', type=float, default=1.0)
+ parser.add_argument('--lam', type=float, default=0.1)
+ parser.add_argument('--K', type=int, default=4)
+ parser.add_argument('--sigma_bridge', type=float, default=0.05)
+ parser.add_argument('--ema_momentum', type=float, default=0.995)
+ parser.add_argument('--pert_eps', type=float, default=1e-3)
+ parser.add_argument('--pert_beta', type=float, default=1.0)
+ parser.add_argument('--seeds', type=int, nargs='+', default=[42])
+ parser.add_argument('--gpu', type=int, default=2)
+ parser.add_argument('--output_dir', type=str, default='results/vector_audit')
+ parser.add_argument('--full_audit', action='store_true',
+ help='Run sanity checks for all seeds (default: first seed only)')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()