From 5550e2cac45758e579810ae36bf716a0b819cebc Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Tue, 24 Mar 2026 18:03:55 -0500 Subject: Add Phase 5: vector field audit, frozen CIFAR transfer, online pilot MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 5A: Audit passes — shuffle control collapses, gains are real Phase 5B: Transfer SUCCESS — vec_M4 beats scalar CB by +0.25 Gamma, +0.31 rho on frozen CIFAR Phase 5C: Online FAILURE — vec does worse than scalar CB online despite better frozen credit Core finding: bottleneck is in local surrogate / co-adaptation, not estimator quality Co-Authored-By: Claude Opus 4.6 (1M context) --- experiments/cifar_frozen_vector_credit.py | 648 +++++++++++++++++++++++ experiments/cifar_online_vector_credit.py | 404 ++++++++++++++ experiments/vector_credit_audit.py | 844 ++++++++++++++++++++++++++++++ 3 files changed, 1896 insertions(+) create mode 100644 experiments/cifar_frozen_vector_credit.py create mode 100644 experiments/cifar_online_vector_credit.py create mode 100644 experiments/vector_credit_audit.py (limited to 'experiments') diff --git a/experiments/cifar_frozen_vector_credit.py b/experiments/cifar_frozen_vector_credit.py new file mode 100644 index 0000000..acd26e6 --- /dev/null +++ b/experiments/cifar_frozen_vector_credit.py @@ -0,0 +1,648 @@ +""" +Phase 5B: Frozen CIFAR Vector Credit Transfer. + +Test whether direct vector credit field can recover better credit than scalar CB +on frozen BP-trained CIFAR representations. + +Methods compared: +- DFA (random) +- StateBridge_eT +- ScalarCB_eT +- ScalarCB_deltaL +- VectorField_eT_M{4,8,16} +- VectorField_deltaL_M{4,8,16} (if resources allow) +""" +import os +import sys +import json +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms +import copy + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.residual_mlp import ResidualMLP +from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema +from models.state_bridge import StateBridgeNet +from metrics.credit_metrics import ( + cosine_similarity_batch, perturbation_correlation, nudging_test +) + + +class VectorCreditNet(nn.Module): + """Direct vector credit field: a_phi(h_l, t_l, s) -> R^d.""" + def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3): + super().__init__() + self.ln = nn.LayerNorm(d_hidden) + self.time_embed = SinusoidalTimeEmbed(time_embed_dim) + input_dim = d_hidden + time_embed_dim + s_dim + layers = [] + for i in range(num_layers): + in_d = input_dim if i == 0 else hidden_dim + layers.append(nn.Linear(in_d, hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, d_hidden)) + self.net = nn.Sequential(*layers) + + def forward(self, h, t, s): + h_normed = self.ln(h) + t_emb = self.time_embed(t) + inp = torch.cat([h_normed, t_emb, s], dim=-1) + return self.net(inp) + + +def get_cifar10(batch_size=128): + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) + testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) + train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) + return train_loader, test_loader + + +def evaluate(model, test_loader, device): + model.eval() + correct, total = 0, 0 + with torch.no_grad(): + for x, y in test_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + logits = model(x) + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + return correct / total + + +def train_bp_reference(model, train_loader, test_loader, device, epochs=100, lr=1e-3, wd=0.01): + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + for epoch in range(1, epochs + 1): + model.train() + total_loss, correct, total = 0, 0, 0 + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + logits = model(x) + loss = F.cross_entropy(logits, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + total_loss += loss.item() * x.size(0) + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + scheduler.step() + if epoch % 20 == 0 or epoch == 1: + test_acc = evaluate(model, test_loader, device) + print(f" [BP ref] Ep {epoch}: loss={total_loss/total:.4f}, test={test_acc:.4f}") + test_acc = evaluate(model, test_loader, device) + print(f" [BP ref] Final: {test_acc:.4f}") + return test_acc + + +# ============================================================================= +# Estimator training functions (all on frozen model) +# ============================================================================= + +def train_state_bridge_frozen(model, train_loader, device, epochs, lr_fb): + d = model.d_hidden + L = model.num_blocks + state_pred = StateBridgeNet(d_hidden=d, s_dim=10, time_embed_dim=32, + hidden_dim=256, num_layers=3).to(device) + state_opt = optim.Adam(state_pred.parameters(), lr=lr_fb) + model.eval() + for epoch in range(1, epochs + 1): + state_pred.train() + total_loss, n = 0, 0 + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + hL_det = hiddens[-1].detach() + state_loss = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + pred_hL = state_pred(h_l_det, t_l, s) + target_norm = hL_det.norm(dim=-1, keepdim=True).clamp(min=1.0) + state_loss += (((pred_hL - hL_det) / target_norm) ** 2).sum(dim=-1).mean() + state_loss /= L + state_opt.zero_grad() + state_loss.backward() + state_opt.step() + total_loss += state_loss.item() * batch + n += batch + if epoch % 20 == 0 or epoch == 1: + print(f" [SB] Ep {epoch}: loss={total_loss/n:.6f}") + return state_pred + + +def train_scalar_cb_frozen(model, train_loader, device, epochs, lr_fb, s_type='eT', + lam=0.1, K=4, sigma_bridge=0.05, ema_momentum=0.995, + term_grad_weight=1.0): + d = model.d_hidden + L = model.num_blocks + s_dim = 10 if s_type == 'eT' else d + value_net = ValueNet(d_hidden=d, s_dim=s_dim, time_embed_dim=32, + hidden_dim=256, num_layers=3).to(device) + value_net_ema = create_ema_model(value_net) + value_opt = optim.Adam(value_net.parameters(), lr=lr_fb) + model.eval() + for epoch in range(1, epochs + 1): + value_net.train() + total_vloss, n = 0, 0 + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + true_loss = F.cross_entropy(logits, y, reduction='none').detach() + hL_det = hiddens[-1].detach() + if s_type == 'eT': + s = e_T.detach() + else: + hL_req = hL_det.clone().requires_grad_(True) + logits_for_s = model.out_head(model.out_ln(hL_req)) + ce_for_s = F.cross_entropy(logits_for_s, y, reduction='sum') + s = torch.autograd.grad(ce_for_s, hL_req, create_graph=False)[0].detach() + + t_L = torch.ones(batch, device=device) + V_terminal = value_net(hL_det, t_L, s) + loss_term = ((V_terminal - true_loss) ** 2).mean() + + loss_tgrad = torch.tensor(0.0, device=device) + if term_grad_weight > 0: + hL_req = hL_det.clone().requires_grad_(True) + V_at_L = value_net(hL_req, t_L, s) + grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0] + hL_req2 = hL_det.clone().requires_grad_(True) + logits_tgt = model.out_head(model.out_ln(hL_req2)) + ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum') + a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach() + loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean() + + loss_bridge = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + t_next = torch.full((batch,), (l + 1) / L, device=device) + V_l = value_net(h_l_det, t_l, s) + with torch.no_grad(): + h_next_det = hiddens[l + 1].detach() + log_terms = [] + for k in range(K): + noise = sigma_bridge * torch.randn_like(h_next_det) + V_next = value_net_ema(h_next_det + noise, t_next, s) + log_terms.append(-V_next / lam) + log_stack = torch.stack(log_terms, dim=-1) + V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K)) + loss_bridge += ((V_l - V_target.detach()) ** 2).mean() + loss_bridge /= L + + vloss = loss_term + loss_bridge + term_grad_weight * loss_tgrad + value_opt.zero_grad() + vloss.backward() + torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) + value_opt.step() + update_ema(value_net, value_net_ema, ema_momentum) + total_vloss += vloss.item() * batch + n += batch + if epoch % 20 == 0 or epoch == 1: + print(f" [CB_{s_type}] Ep {epoch}: vloss={total_vloss/n:.6f}") + return value_net + + +def train_vector_field_frozen(model, train_loader, device, epochs, lr_fb, + s_type='eT', M=4, eps=1e-3, beta=1.0, + term_weight=1.0): + """ + Train vector credit field on frozen CIFAR features. + Layer subsampling: each batch, randomly pick one layer for perturbation target. + Terminal matching always uses layer L. + """ + d = model.d_hidden + L = model.num_blocks + s_dim = 10 if s_type == 'eT' else d + + vector_net = VectorCreditNet(d_hidden=d, s_dim=s_dim, time_embed_dim=32, + hidden_dim=256, num_layers=3).to(device) + vec_opt = optim.Adam(vector_net.parameters(), lr=lr_fb) + model.eval() + + for epoch in range(1, epochs + 1): + vector_net.train() + total_vloss, n = 0, 0 + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + + hL_det = hiddens[-1].detach() + + # Compute s + if s_type == 'eT': + s = e_T.detach() + else: + hL_req = hL_det.clone().requires_grad_(True) + logits_for_s = model.out_head(model.out_ln(hL_req)) + ce_for_s = F.cross_entropy(logits_for_s, y, reduction='sum') + s = torch.autograd.grad(ce_for_s, hL_req, create_graph=False)[0].detach() + + # Terminal matching + loss_term = torch.tensor(0.0, device=device) + if term_weight > 0: + t_L = torch.ones(batch, device=device) + a_terminal = vector_net(hL_det, t_L, s) + hL_req = hL_det.clone().requires_grad_(True) + logits_tgt = model.out_head(model.out_ln(hL_req)) + ce = F.cross_entropy(logits_tgt, y, reduction='sum') + delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach() + loss_term = ((a_terminal - delta_L) ** 2).sum(dim=-1).mean() + + # Perturbation target — subsample 1 random layer per batch + l = np.random.randint(0, L) + h_l_det = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + a_l = vector_net(h_l_det, t_l, s) + + loss_proj = torch.tensor(0.0, device=device) + for _ in range(M): + v = torch.randn_like(h_l_det) + v = v / (v.norm(dim=-1, keepdim=True) + 1e-8) + + with torch.no_grad(): + # Use model.forward_from_layer for tail forward + logits_plus = model.forward_from_layer(h_l_det + eps * v, l) + loss_plus = F.cross_entropy(logits_plus, y, reduction='none') + logits_minus = model.forward_from_layer(h_l_det - eps * v, l) + loss_minus = F.cross_entropy(logits_minus, y, reduction='none') + g_j = (loss_plus - loss_minus) / (2 * eps) + + pred_j = (a_l * v).sum(dim=-1) + loss_proj = loss_proj + ((pred_j - g_j.detach()) ** 2).mean() + loss_proj = loss_proj / M + + vloss = term_weight * loss_term + beta * loss_proj + vec_opt.zero_grad() + vloss.backward() + torch.nn.utils.clip_grad_norm_(vector_net.parameters(), 1.0) + vec_opt.step() + total_vloss += vloss.item() * batch + n += batch + + if epoch % 20 == 0 or epoch == 1: + print(f" [vec_{s_type}_M{M}] Ep {epoch}: vloss={total_vloss/n:.6f}") + + return vector_net + + +# ============================================================================= +# Evaluation +# ============================================================================= +def evaluate_all(model, test_loader, device, estimators): + """Evaluate credit quality for all estimators on frozen features.""" + model.eval() + d = model.d_hidden + L = model.num_blocks + + # DFA baseline + dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)] + + # Use multiple test batches for robust Gamma, single batch for rho/nudge (expensive) + results = {} + for name in list(estimators.keys()) + ['dfa']: + results[name] = { + 'bp_cosine': [[] for _ in range(L)], + 'perturbation_rho': [0.0] * L, + 'nudging_0.001': [0.0] * L, + 'nudging_0.003': [0.0] * L, + 'nudging_0.01': [0.0] * L, + } + + n_batches = min(10, len(test_loader)) + batch_idx = 0 + + for x, y in test_loader: + if batch_idx >= n_batches: + break + batch_idx += 1 + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + + # BP gradients + for p in model.parameters(): + p.requires_grad_(True) + model.zero_grad() + logits_bp, hiddens_bp = model(x, return_hidden=True) + for l in range(L + 1): + hiddens_bp[l].retain_grad() + loss_bp = F.cross_entropy(logits_bp, y) + loss_bp.backward() + bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)} + for p in model.parameters(): + p.requires_grad_(False) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s_eT = e_T.detach() + + hL_det = hiddens[-1].detach() + hL_req = hL_det.clone().requires_grad_(True) + logits_delta = model.out_head(model.out_ln(hL_req)) + ce_delta = F.cross_entropy(logits_delta, y, reduction='sum') + delta_L = torch.autograd.grad(ce_delta, hL_req, create_graph=False)[0].detach() + + for l in range(L): + h_l = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + + def make_fwd_fn(start_l): + def fwd_fn(h): + with torch.no_grad(): + curr = h + for i in range(start_l, L): + curr = curr + model.blocks[i](curr) + out = model.out_head(model.out_ln(curr)) + return F.cross_entropy(out, y, reduction='none') + return fwd_fn + fwd_fn = make_fwd_fn(l) + + # DFA + a_dfa = (s_eT @ dfa_Bs[l].T).detach() + results['dfa']['bp_cosine'][l].append(cosine_similarity_batch(a_dfa, bp_grads[l])) + if batch_idx == 1: + results['dfa']['perturbation_rho'][l] = perturbation_correlation(h_l, a_dfa, fwd_fn, epsilon=1e-3, M=32) + for eta in [0.001, 0.003, 0.01]: + results['dfa'][f'nudging_{eta}'][l] = nudging_test(h_l, a_dfa, fwd_fn, eta=eta) + + # Other estimators + for name, est in estimators.items(): + if est['type'] == 'sb': + net = est['net'] + net.eval() + h_l_req = h_l.clone().requires_grad_(True) + pred_hL = net(h_l_req, t_l, s_eT) + pred_logits = model.out_head(model.out_ln(pred_hL)) + pred_loss = F.cross_entropy(pred_logits, y, reduction='sum') + a_l = torch.autograd.grad(pred_loss, h_l_req, create_graph=False)[0].detach() + elif est['type'] == 'cb': + net = est['net'] + net.eval() + s = s_eT if est['s_type'] == 'eT' else delta_L + h_l_req = h_l.clone().requires_grad_(True) + V_l = net(h_l_req, t_l, s) + a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach() + elif est['type'] == 'vec': + net = est['net'] + net.eval() + s = s_eT if est['s_type'] == 'eT' else delta_L + a_l = net(h_l, t_l, s).detach() + + results[name]['bp_cosine'][l].append(cosine_similarity_batch(a_l, bp_grads[l])) + if batch_idx == 1: + results[name]['perturbation_rho'][l] = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=32) + for eta in [0.001, 0.003, 0.01]: + results[name][f'nudging_{eta}'][l] = nudging_test(h_l, a_l, fwd_fn, eta=eta) + + # Average bp_cosine + for name in results: + for l in range(L): + vals = results[name]['bp_cosine'][l] + results[name]['bp_cosine'][l] = float(np.mean(vals)) if vals else 0.0 + return results + + +# ============================================================================= +# Main +# ============================================================================= +def run_experiment(args): + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + os.makedirs(args.output_dir, exist_ok=True) + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + train_loader, test_loader = get_cifar10(batch_size=args.batch_size) + input_dim = 32 * 32 * 3 + + # Step 1: Load/train BP reference + bp_ckpt = os.path.join(args.output_dir, f'bp_ref_L{args.num_blocks}_d{args.d_hidden}_s{args.seed}.pt') + model = ResidualMLP(input_dim, args.d_hidden, 10, args.num_blocks).to(device) + + # Try loading from frozen_cifar directory first + alt_ckpt = f'results/frozen_cifar/bp_ref_L{args.num_blocks}_d{args.d_hidden}_s{args.seed}.pt' + if os.path.exists(alt_ckpt) and not args.retrain_bp: + print(f" Loading BP ref from {alt_ckpt}") + model.load_state_dict(torch.load(alt_ckpt, map_location=device)) + bp_acc = evaluate(model, test_loader, device) + elif os.path.exists(bp_ckpt) and not args.retrain_bp: + print(f" Loading BP ref from {bp_ckpt}") + model.load_state_dict(torch.load(bp_ckpt, map_location=device)) + bp_acc = evaluate(model, test_loader, device) + else: + bp_acc = train_bp_reference(model, train_loader, test_loader, device, + epochs=args.bp_epochs, lr=args.lr, wd=args.wd) + torch.save(model.state_dict(), bp_ckpt) + print(f" BP ref acc: {bp_acc:.4f}") + + model.eval() + for p in model.parameters(): + p.requires_grad_(False) + + L = args.num_blocks + d = args.d_hidden + + # Step 2: Train estimators + print(f"\n{'='*60}") + print(f"Training estimators (L={L}, d={d}, {args.estimator_epochs} epochs)") + print(f"{'='*60}") + + estimators = {} + + # StateBridge_eT + print("\n--- StateBridge_eT ---") + torch.manual_seed(args.seed + 1000) + sb = train_state_bridge_frozen(model, train_loader, device, args.estimator_epochs, args.lr_fb) + estimators['sb_eT'] = {'type': 'sb', 'net': sb, 's_type': 'eT'} + + # ScalarCB_eT + print("\n--- ScalarCB_eT ---") + torch.manual_seed(args.seed + 2000) + cb_eT = train_scalar_cb_frozen(model, train_loader, device, args.estimator_epochs, args.lr_fb, + s_type='eT', term_grad_weight=args.term_grad_weight) + estimators['cb_eT'] = {'type': 'cb', 'net': cb_eT, 's_type': 'eT'} + + # ScalarCB_deltaL + print("\n--- ScalarCB_deltaL ---") + torch.manual_seed(args.seed + 3000) + cb_dL = train_scalar_cb_frozen(model, train_loader, device, args.estimator_epochs, args.lr_fb, + s_type='deltaL', term_grad_weight=args.term_grad_weight) + estimators['cb_deltaL'] = {'type': 'cb', 'net': cb_dL, 's_type': 'deltaL'} + + # Vector fields + for M in args.M_values: + for s_type in args.vec_s_types: + tag = f'vec_{s_type}_M{M}' + print(f"\n--- {tag} ---") + torch.manual_seed(args.seed + 4000 + M * 100 + (0 if s_type == 'eT' else 1)) + vnet = train_vector_field_frozen(model, train_loader, device, + args.estimator_epochs, args.lr_fb, + s_type=s_type, M=M, eps=args.pert_eps, + beta=args.pert_beta, term_weight=args.term_weight_vec) + estimators[tag] = {'type': 'vec', 'net': vnet, 's_type': s_type} + + # Step 3: Evaluate + print(f"\n{'='*60}") + print("Evaluating credit quality") + print(f"{'='*60}") + results = evaluate_all(model, test_loader, device, estimators) + + # Print summary + all_methods = ['dfa', 'sb_eT', 'cb_eT', 'cb_deltaL'] + \ + [f'vec_{st}_M{M}' for M in args.M_values for st in args.vec_s_types] + labels = { + 'dfa': 'DFA', 'sb_eT': 'StateBridge_eT', + 'cb_eT': 'ScalarCB_eT', 'cb_deltaL': 'ScalarCB_deltaL', + } + for M in args.M_values: + for st in args.vec_s_types: + labels[f'vec_{st}_M{M}'] = f'Vec_{st}_M{M}' + + print(f"\n{'Method':<25} {'Gamma':>8} {'rho':>8} {'nudge':>10}") + print("-" * 55) + + summary = {} + for m in all_methods: + if m not in results: + continue + r = results[m] + mg = np.mean(r['bp_cosine']) + mr = np.mean(r['perturbation_rho']) + mn = np.mean(r['nudging_0.003']) + summary[m] = {'mean_gamma': float(mg), 'mean_rho': float(mr), 'mean_nudge': float(mn)} + print(f"{labels.get(m, m):<25} {mg:>8.4f} {mr:>8.4f} {mn:>10.6f}") + + # Per-layer detail + print(f"\n--- Per-layer Gamma ---") + for l in range(L): + row = f" L{l}: " + for m in all_methods: + if m in results: + row += f" {results[m]['bp_cosine'][l]:>8.4f}" + print(row) + + print(f"\n--- Per-layer rho ---") + for l in range(L): + row = f" L{l}: " + for m in all_methods: + if m in results: + row += f" {results[m]['perturbation_rho'][l]:>8.4f}" + print(row) + + # Save + save_data = { + 'config': { + 'num_blocks': L, 'd_hidden': d, 'seed': args.seed, + 'bp_acc': float(bp_acc), 'estimator_epochs': args.estimator_epochs, + }, + 'summary': summary, + 'per_layer': {m: { + 'bp_cosine': results[m]['bp_cosine'], + 'perturbation_rho': results[m]['perturbation_rho'], + 'nudging_0.003': results[m]['nudging_0.003'], + } for m in all_methods if m in results}, + } + out_path = os.path.join(args.output_dir, + f'frozen_vec_L{L}_d{d}_s{args.seed}.json') + with open(out_path, 'w') as f: + json.dump(save_data, f, indent=2) + print(f"\nResults saved to {out_path}") + + # Judgment + cb_eT_gamma = summary.get('cb_eT', {}).get('mean_gamma', 0) + cb_eT_rho = summary.get('cb_eT', {}).get('mean_rho', 0) + best_vec_gamma = max(summary.get(m, {}).get('mean_gamma', 0) for m in summary if m.startswith('vec_')) + best_vec_rho = max(summary.get(m, {}).get('mean_rho', 0) for m in summary if m.startswith('vec_')) + best_vec_name = max((m for m in summary if m.startswith('vec_')), + key=lambda m: summary[m]['mean_gamma'] + summary[m]['mean_rho'], + default='none') + + print(f"\n{'='*60}") + print("JUDGMENT") + print(f"{'='*60}") + print(f"ScalarCB_eT: Gamma={cb_eT_gamma:.4f}, rho={cb_eT_rho:.4f}") + print(f"Best vector ({best_vec_name}): Gamma={best_vec_gamma:.4f}, rho={best_vec_rho:.4f}") + + dg = best_vec_gamma - cb_eT_gamma + dr = best_vec_rho - cb_eT_rho + print(f"Delta: Gamma={dg:+.4f}, rho={dr:+.4f}") + + if dg >= 0.05 and dr >= 0.05: + print("TRANSFER SUCCESS: Vector field significantly outperforms scalar CB on frozen CIFAR.") + elif dg > 0 and dr > 0: + print("MARGINAL: Vector field slightly better, but deltas below 0.05 threshold.") + else: + print("TRANSFER FAILED: Vector field does not outperform scalar CB on frozen CIFAR.") + + return save_data + + +def main(): + parser = argparse.ArgumentParser(description='Phase 5B: Frozen CIFAR Vector Transfer') + parser.add_argument('--num_blocks', type=int, default=4) + parser.add_argument('--d_hidden', type=int, default=256) + parser.add_argument('--batch_size', type=int, default=128) + parser.add_argument('--bp_epochs', type=int, default=100) + parser.add_argument('--estimator_epochs', type=int, default=100) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--lr_fb', type=float, default=1e-3) + parser.add_argument('--wd', type=float, default=0.01) + parser.add_argument('--term_grad_weight', type=float, default=1.0) + parser.add_argument('--term_weight_vec', type=float, default=1.0) + parser.add_argument('--pert_eps', type=float, default=1e-3) + parser.add_argument('--pert_beta', type=float, default=1.0) + parser.add_argument('--M_values', type=int, nargs='+', default=[4, 8, 16]) + parser.add_argument('--vec_s_types', type=str, nargs='+', default=['eT']) + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--gpu', type=int, default=2) + parser.add_argument('--output_dir', type=str, default='results/frozen_cifar_vec') + parser.add_argument('--retrain_bp', action='store_true') + args = parser.parse_args() + run_experiment(args) + + +if __name__ == '__main__': + main() diff --git a/experiments/cifar_online_vector_credit.py b/experiments/cifar_online_vector_credit.py new file mode 100644 index 0000000..3a3762c --- /dev/null +++ b/experiments/cifar_online_vector_credit.py @@ -0,0 +1,404 @@ +""" +Phase 5C: Online Shallow CIFAR Vector Credit Pilot. + +Minimal pilot: does vector field's frozen credit gain translate to online training? + +Compare DFA, ScalarCB_eT, VectorField_eT_M4 on CIFAR-10, L=4, d=256. +Sweep warmup_ratio and term_weight. +""" +import os +import sys +import json +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.residual_mlp import ResidualMLP +from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema +from metrics.credit_metrics import ( + cosine_similarity_batch, perturbation_correlation, nudging_test +) + + +class VectorCreditNet(nn.Module): + def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3): + super().__init__() + self.ln = nn.LayerNorm(d_hidden) + self.time_embed = SinusoidalTimeEmbed(time_embed_dim) + input_dim = d_hidden + time_embed_dim + s_dim + layers = [] + for i in range(num_layers): + in_d = input_dim if i == 0 else hidden_dim + layers.append(nn.Linear(in_d, hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, d_hidden)) + self.net = nn.Sequential(*layers) + + def forward(self, h, t, s): + h_normed = self.ln(h) + t_emb = self.time_embed(t) + inp = torch.cat([h_normed, t_emb, s], dim=-1) + return self.net(inp) + + +def get_cifar10(batch_size=128): + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) + testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) + train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) + return train_loader, test_loader + + +def evaluate(model, test_loader, device): + model.eval() + correct, total = 0, 0 + with torch.no_grad(): + for x, y in test_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + logits = model(x) + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + return correct / total + + +def train_dfa(model, train_loader, test_loader, device, epochs, lr, wd): + d = model.d_hidden + L = model.num_blocks + Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)] + block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) + head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr, weight_decay=wd) + scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)] + log = {'train_loss': [], 'test_acc': []} + for epoch in range(1, epochs + 1): + model.train() + total_loss, correct, total = 0, 0, 0 + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + hL = hiddens[-1].detach() + loss_out = F.cross_entropy(model.out_head(model.out_ln(hL)), y) + head_opt.zero_grad(); loss_out.backward(); head_opt.step() + for l in range(L): + a = (e_T @ Bs[l].T).detach() + rms = (a**2).mean(-1, keepdim=True).sqrt() + 1e-6 + f = model.blocks[l](hiddens[l].detach()) + ll = (f * (a/rms)).sum(-1).mean() + block_opts[l].zero_grad(); ll.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + a0 = (e_T @ Bs[0].T).detach() + rms0 = (a0**2).mean(-1, keepdim=True).sqrt() + 1e-6 + el = (model.embed(x) * (a0/rms0)).sum(-1).mean() + embed_opt.zero_grad(); el.backward(); embed_opt.step() + total_loss += loss_val.item() * batch; correct += (logits.argmax(1) == y).sum().item(); total += batch + for s in scheds: s.step() + test_acc = evaluate(model, test_loader, device) + log['train_loss'].append(total_loss/total); log['test_acc'].append(test_acc) + if epoch % 20 == 0 or epoch == 1: + print(f" [DFA] Ep {epoch}: loss={total_loss/total:.4f}, test={test_acc:.4f}") + return log, Bs + + +def train_vector_online(model, train_loader, test_loader, device, epochs, lr, lr_fb, wd, + M=4, warmup_ratio=0.2, term_weight=1.0, eps=1e-3, beta=1.0): + d = model.d_hidden + L = model.num_blocks + warmup_epochs = max(1, int(epochs * warmup_ratio)) + + vector_net = VectorCreditNet(d_hidden=d, s_dim=10, time_embed_dim=32, + hidden_dim=256, num_layers=3).to(device) + Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)] + + block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) + head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr, weight_decay=wd) + vec_opt = optim.Adam(vector_net.parameters(), lr=lr_fb) + scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)] + + log = {'train_loss': [], 'test_acc': [], 'vloss': []} + + for epoch in range(1, epochs + 1): + model.train(); vector_net.train() + credit_blend = 0.0 if epoch <= warmup_epochs else min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs)) + total_loss, correct, total, total_vloss = 0, 0, 0, 0 + + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + + hL = hiddens[-1].detach() + + # Train vector net: terminal matching + loss_term = torch.tensor(0.0, device=device) + if term_weight > 0: + t_L = torch.ones(batch, device=device) + a_term = vector_net(hL, t_L, s) + hL_req = hL.clone().requires_grad_(True) + logits_tgt = model.out_head(model.out_ln(hL_req)) + ce = F.cross_entropy(logits_tgt, y, reduction='sum') + delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach() + loss_term = ((a_term - delta_L) ** 2).sum(-1).mean() + + # Perturbation target: subsample 1 layer + l_train = np.random.randint(0, L) + h_l = hiddens[l_train].detach() + t_l = torch.full((batch,), l_train / L, device=device) + a_l = vector_net(h_l, t_l, s) + + loss_proj = torch.tensor(0.0, device=device) + for _ in range(M): + v = torch.randn_like(h_l) + v = v / (v.norm(dim=-1, keepdim=True) + 1e-8) + with torch.no_grad(): + lp = F.cross_entropy(model.forward_from_layer(h_l + eps*v, l_train), y, reduction='none') + lm = F.cross_entropy(model.forward_from_layer(h_l - eps*v, l_train), y, reduction='none') + g_j = (lp - lm) / (2*eps) + loss_proj = loss_proj + (((a_l * v).sum(-1) - g_j.detach())**2).mean() + loss_proj = loss_proj / M + + vloss = term_weight * loss_term + beta * loss_proj + vec_opt.zero_grad(); vloss.backward() + torch.nn.utils.clip_grad_norm_(vector_net.parameters(), 1.0) + vec_opt.step() + total_vloss += vloss.item() * batch + + # Compute credits + with torch.no_grad(): + vec_credits = [vector_net(hiddens[l].detach(), + torch.full((batch,), l/L, device=device), s).detach() for l in range(L)] + dfa_credits = [(e_T @ Bs[l].T).detach() for l in range(L)] + + credits = [] + for l in range(L): + if credit_blend >= 1.0: + credits.append(vec_credits[l]) + elif credit_blend <= 0.0: + credits.append(dfa_credits[l]) + else: + vr = (vec_credits[l]**2).mean(-1, keepdim=True).sqrt() + 1e-6 + dr = (dfa_credits[l]**2).mean(-1, keepdim=True).sqrt() + 1e-6 + credits.append(credit_blend * vec_credits[l]/vr + (1-credit_blend) * dfa_credits[l]/dr) + + # Update head + loss_out = F.cross_entropy(model.out_head(model.out_ln(hL)), y) + head_opt.zero_grad(); loss_out.backward(); head_opt.step() + + # Update blocks + for l in range(L): + a = credits[l] + rms = (a**2).mean(-1, keepdim=True).sqrt() + 1e-6 + f = model.blocks[l](hiddens[l].detach()) + ll = (f * (a/rms)).sum(-1).mean() + block_opts[l].zero_grad(); ll.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + + # Update embedding + a0 = credits[0] + rms0 = (a0**2).mean(-1, keepdim=True).sqrt() + 1e-6 + el = (model.embed(x) * (a0/rms0)).sum(-1).mean() + embed_opt.zero_grad(); el.backward(); embed_opt.step() + + total_loss += loss_val.item()*batch; correct += (logits.argmax(1)==y).sum().item(); total += batch + + for s in scheds: s.step() + test_acc = evaluate(model, test_loader, device) + log['train_loss'].append(total_loss/total); log['test_acc'].append(test_acc) + log['vloss'].append(total_vloss/total) + if epoch % 20 == 0 or epoch == 1: + phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}" + print(f" [vec_M{M}] Ep {epoch} ({phase}): loss={total_loss/total:.4f}, test={test_acc:.4f}") + + return log, vector_net + + +def compute_diagnostics(model, test_loader, device, method_name, value_net=None, vector_net=None, dfa_Bs=None): + model.eval() + if value_net: value_net.eval() + if vector_net: vector_net.eval() + L = model.num_blocks + for x, y in test_loader: + x = x.view(x.size(0), -1).to(device); y = y.to(device); break + batch = x.size(0) + + logits_bp, hiddens_bp = model(x, return_hidden=True) + for l in range(L+1): hiddens_bp[l].retain_grad() + F.cross_entropy(logits_bp, y).backward() + bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L+1)} + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1; s = e_T.detach() + + results = {'bp_cosine': [], 'perturbation_rho': [], 'nudging': {'0.001':[], '0.003':[], '0.01':[]}} + for l in range(L): + h_l = hiddens[l].detach(); t_l = torch.full((batch,), l/L, device=device) + if method_name == 'dfa': + a_l = (s @ dfa_Bs[l].T).detach() + elif method_name.startswith('vec'): + a_l = vector_net(h_l, t_l, s).detach() + results['bp_cosine'].append(float(cosine_similarity_batch(a_l, bp_grads[l]))) + def make_fwd(sl): + def f(h): + with torch.no_grad(): + c=h + for i in range(sl,L): c=c+model.blocks[i](c) + return F.cross_entropy(model.out_head(model.out_ln(c)),y,reduction='none') + return f + fwd = make_fwd(l) + results['perturbation_rho'].append(float(perturbation_correlation(h_l, a_l, fwd, epsilon=1e-3, M=16))) + for eta in [0.001, 0.003, 0.01]: + results['nudging'][str(eta)].append(float(nudging_test(h_l, a_l, fwd, eta=eta))) + return results + + +def run_config(L, d, method, seed, train_loader, test_loader, device, + epochs=100, lr=1e-3, lr_fb=1e-3, wd=0.01, + M=4, warmup_ratio=0.2, term_weight=1.0, eps=1e-3, beta=1.0): + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + model = ResidualMLP(32*32*3, d, 10, L).to(device) + config_str = f"L={L}, d={d}, {method}, s={seed}" + if method.startswith('vec'): config_str += f", wr={warmup_ratio}, tw={term_weight}" + print(f"\n --- {config_str} ---") + + if method == 'dfa': + log, Bs = train_dfa(model, train_loader, test_loader, device, epochs, lr, wd) + diag = compute_diagnostics(model, test_loader, device, 'dfa', dfa_Bs=Bs) + elif method.startswith('vec'): + log, vnet = train_vector_online(model, train_loader, test_loader, device, + epochs, lr, lr_fb, wd, M=M, + warmup_ratio=warmup_ratio, term_weight=term_weight, + eps=eps, beta=beta) + diag = compute_diagnostics(model, test_loader, device, 'vec', vector_net=vnet) + + result = { + 'method': method, 'L': L, 'd': d, 'seed': seed, + 'warmup_ratio': warmup_ratio, 'term_weight': term_weight, 'M': M, + 'test_acc': log['test_acc'][-1], + 'mean_gamma': float(np.mean(diag['bp_cosine'])), + 'mean_rho': float(np.mean(diag['perturbation_rho'])), + 'mean_nudge': float(np.mean(diag['nudging']['0.003'])), + 'per_layer_gamma': diag['bp_cosine'], + 'per_layer_rho': diag['perturbation_rho'], + } + print(f" Result: acc={result['test_acc']:.4f}, Gamma={result['mean_gamma']:.4f}, " + f"rho={result['mean_rho']:.4f}, nudge={result['mean_nudge']:.6f}") + return result + + +def main(): + parser = argparse.ArgumentParser(description='Phase 5C: Online CIFAR Vector Pilot') + parser.add_argument('--L', type=int, default=4) + parser.add_argument('--d', type=int, default=256) + parser.add_argument('--epochs', type=int, default=100) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--lr_fb', type=float, default=1e-3) + parser.add_argument('--wd', type=float, default=0.01) + parser.add_argument('--M', type=int, default=4) + parser.add_argument('--warmup_ratios', type=float, nargs='+', default=[0.0, 0.05, 0.2]) + parser.add_argument('--term_weights', type=float, nargs='+', default=[1.0, 4.0]) + parser.add_argument('--pert_eps', type=float, default=1e-3) + parser.add_argument('--pert_beta', type=float, default=1.0) + parser.add_argument('--seeds', type=int, nargs='+', default=[42]) + parser.add_argument('--batch_size', type=int, default=128) + parser.add_argument('--gpu', type=int, default=2) + parser.add_argument('--output_dir', type=str, default='results/online_vec_pilot') + args = parser.parse_args() + + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + os.makedirs(args.output_dir, exist_ok=True) + train_loader, test_loader = get_cifar10(args.batch_size) + + all_results = [] + + for seed in args.seeds: + # DFA baseline + r = run_config(args.L, args.d, 'dfa', seed, train_loader, test_loader, device, + args.epochs, args.lr, args.lr_fb, args.wd) + all_results.append(r) + + # Vector field sweep + for wr in args.warmup_ratios: + for tw in args.term_weights: + r = run_config(args.L, args.d, 'vec_eT_M4', seed, train_loader, test_loader, device, + args.epochs, args.lr, args.lr_fb, args.wd, + M=args.M, warmup_ratio=wr, term_weight=tw, + eps=args.pert_eps, beta=args.pert_beta) + all_results.append(r) + + # Summary + dfa_baselines = {r['seed']: r for r in all_results if r['method'] == 'dfa'} + print(f"\n{'='*90}") + print("SUMMARY") + print(f"{'='*90}") + print(f"{'Method':<20} {'seed':>5} {'wr':>5} {'tw':>5} {'Acc':>6} {'Gamma':>7} {'rho':>7} {'nudge':>10} {'S1':>7} {'S2':>7}") + print("-" * 90) + + positive = [] + for r in all_results: + dfa = dfa_baselines.get(r['seed'], {}) + S1 = r['mean_gamma'] - dfa.get('mean_gamma', 0) + S2 = r['mean_rho'] - dfa.get('mean_rho', 0) + wr_s = f"{r.get('warmup_ratio', '-'):>5.2f}" if r['method'] != 'dfa' else " -" + tw_s = f"{r.get('term_weight', '-'):>5.1f}" if r['method'] != 'dfa' else " -" + print(f"{r['method']:<20} {r['seed']:>5} {wr_s} {tw_s} {r['test_acc']:>6.4f} " + f"{r['mean_gamma']:>7.4f} {r['mean_rho']:>7.4f} {r['mean_nudge']:>10.6f} {S1:>7.4f} {S2:>7.4f}") + if r['method'] != 'dfa' and S1 > 0 and S2 > 0: + nb = r['mean_nudge'] < dfa.get('mean_nudge', 0) + positive.append({**r, 'S1': S1, 'S2': S2, 'nudge_better': nb}) + + if positive: + print(f"\nPOSITIVE CONFIGS (S1>0 AND S2>0):") + for p in positive: + print(f" {p['method']} wr={p['warmup_ratio']} tw={p['term_weight']}: " + f"S1={p['S1']:.4f} S2={p['S2']:.4f} nudge_better={p['nudge_better']}") + else: + print(f"\nNO POSITIVE CONFIGS.") + + out_path = os.path.join(args.output_dir, f'pilot_s{args.seeds[0]}.json') + with open(out_path, 'w') as f: + json.dump(all_results, f, indent=2) + print(f"\nSaved to {out_path}") + + +if __name__ == '__main__': + main() diff --git a/experiments/vector_credit_audit.py b/experiments/vector_credit_audit.py new file mode 100644 index 0000000..048efb7 --- /dev/null +++ b/experiments/vector_credit_audit.py @@ -0,0 +1,844 @@ +""" +Phase 5A: Vector Credit Field Audit. + +Verify that the vector field's gains are real, not implementation artifacts. + +4 mandatory sanity checks: +A. Train/eval direction split (independent random directions) +B. Shuffled-target control (permute g_j within batch) +C. No-terminal ablation (L_term = 0) +D. One-sided vs symmetric finite difference +""" +import os +import sys +import json +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import copy + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema +from metrics.credit_metrics import ( + cosine_similarity_batch, perturbation_correlation, nudging_test +) + + +# ============================================================================= +# Synthetic teacher-student +# ============================================================================= +class TeacherNet(nn.Module): + def __init__(self, d_hidden, num_classes, num_blocks, alpha=1.0, seed=0): + super().__init__() + self.d_hidden = d_hidden + self.num_blocks = num_blocks + self.alpha = alpha + rng = torch.Generator().manual_seed(seed) + self.Ws = nn.ParameterList() + for _ in range(num_blocks): + W = torch.randn(d_hidden, d_hidden, generator=rng) * 0.3 / (d_hidden ** 0.5) + U, S, Vh = torch.linalg.svd(W, full_matrices=False) + S_clamped = S.clamp(max=0.3) + W = U @ torch.diag(S_clamped) @ Vh + self.Ws.append(nn.Parameter(W, requires_grad=False)) + self.U = nn.Parameter( + torch.randn(num_classes, d_hidden, generator=rng) / (d_hidden ** 0.5), + requires_grad=False) + + def phi(self, z): + return (1 - self.alpha) * z + self.alpha * torch.tanh(z) + + def forward(self, x): + h = x + for W in self.Ws: + h = h + self.phi(h @ W.T) + return h @ self.U.T + + +class StudentBlock(nn.Module): + def __init__(self, d_hidden, alpha=1.0): + super().__init__() + self.ln = nn.LayerNorm(d_hidden) + self.w = nn.Linear(d_hidden, d_hidden, bias=False) + nn.init.normal_(self.w.weight, std=0.01) + self.alpha = alpha + + def phi(self, z): + return (1 - self.alpha) * z + self.alpha * torch.tanh(z) + + def forward(self, h): + return self.w(self.phi(self.ln(h))) + + +class StudentNet(nn.Module): + def __init__(self, d_hidden, num_classes, num_blocks, alpha=1.0): + super().__init__() + self.blocks = nn.ModuleList([StudentBlock(d_hidden, alpha) for _ in range(num_blocks)]) + self.out_head = nn.Linear(d_hidden, num_classes) + self.d_hidden = d_hidden + self.num_blocks = num_blocks + + def forward(self, x, return_hidden=False): + h = x + hiddens = [h] if return_hidden else None + for block in self.blocks: + f = block(h) + h = h + f + if return_hidden: + hiddens.append(h) + logits = self.out_head(h) + if return_hidden: + return logits, hiddens + return logits + + def forward_from_layer(self, h, start_layer): + for i in range(start_layer, self.num_blocks): + f = self.blocks[i](h) + h = h + f + return self.out_head(h) + + +class VectorCreditNet(nn.Module): + """Direct vector credit field: a_phi(h_l, t_l, s) -> R^d.""" + def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3): + super().__init__() + self.ln = nn.LayerNorm(d_hidden) + self.time_embed = SinusoidalTimeEmbed(time_embed_dim) + input_dim = d_hidden + time_embed_dim + s_dim + layers = [] + for i in range(num_layers): + in_d = input_dim if i == 0 else hidden_dim + layers.append(nn.Linear(in_d, hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, d_hidden)) + self.net = nn.Sequential(*layers) + + def forward(self, h, t, s): + h_normed = self.ln(h) + t_emb = self.time_embed(t) + inp = torch.cat([h_normed, t_emb, s], dim=-1) + return self.net(inp) + + +def generate_batch(teacher, d_hidden, num_classes, batch_size, device): + x = torch.randn(batch_size, d_hidden, device=device) + with torch.no_grad(): + teacher_logits = teacher(x) + y = teacher_logits.argmax(dim=-1) + return x, y + + +# ============================================================================= +# Training: vector field with audit controls +# ============================================================================= +def train_vector_field_audit(model, teacher, device, args, M=4, + use_terminal=True, + shuffle_targets=False, + use_central_diff=True, + tag='vec'): + """ + Train vector credit field with configurable audit controls. + + Args: + use_terminal: if False, L_term = 0 (no-terminal ablation) + shuffle_targets: if True, permute g_j within batch (leak check) + use_central_diff: if True, central difference; if False, one-sided + tag: label for printing + """ + d = model.d_hidden + L = model.num_blocks + num_classes = args.num_classes + + vector_net = VectorCreditNet(d_hidden=d, s_dim=num_classes, time_embed_dim=32, + hidden_dim=256, num_layers=3).to(device) + + Bs = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)] + + block_opts = [optim.AdamW(b.parameters(), lr=args.lr, weight_decay=0.01) for b in model.blocks] + head_opt = optim.AdamW(model.out_head.parameters(), lr=args.lr, weight_decay=0.01) + vec_opt = optim.Adam(vector_net.parameters(), lr=args.lr_fb) + + warmup_epochs = max(1, int(args.epochs * args.warmup_ratio)) + eps = args.pert_eps + beta = args.pert_beta + + for epoch in range(1, args.epochs + 1): + model.train() + vector_net.train() + + if epoch <= warmup_epochs: + credit_blend = 0.0 + else: + credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs)) + + total_loss, correct, total = 0, 0, 0 + total_vloss = 0 + + for _ in range(args.steps_per_epoch): + x, y = generate_batch(teacher, d, num_classes, args.batch_size, device) + batch = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + + hL_det = hiddens[-1].detach() + + # --- Terminal matching --- + loss_term = torch.tensor(0.0, device=device) + if use_terminal: + t_L = torch.ones(batch, device=device) + a_terminal = vector_net(hL_det, t_L, s) + hL_req = hL_det.clone().requires_grad_(True) + logits_tgt = model.out_head(hL_req) + ce = F.cross_entropy(logits_tgt, y, reduction='sum') + delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach() + loss_term = ((a_terminal - delta_L) ** 2).sum(dim=-1).mean() + + # --- Perturbation directional targets --- + # IMPORTANT: training directions are sampled fresh each step. + # Evaluation uses independently sampled directions (see compute_diagnostics). + loss_proj = torch.tensor(0.0, device=device) + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + a_l = vector_net(h_l_det, t_l, s) + + layer_proj_loss = 0.0 + for _ in range(M): + v = torch.randn_like(h_l_det) + v = v / (v.norm(dim=-1, keepdim=True) + 1e-8) + + with torch.no_grad(): + if use_central_diff: + # Central difference: [loss(h+eps*v) - loss(h-eps*v)] / (2*eps) + logits_plus = model.forward_from_layer(h_l_det + eps * v, l) + loss_plus = F.cross_entropy(logits_plus, y, reduction='none') + logits_minus = model.forward_from_layer(h_l_det - eps * v, l) + loss_minus = F.cross_entropy(logits_minus, y, reduction='none') + g_j = (loss_plus - loss_minus) / (2 * eps) + else: + # One-sided difference: [loss(h+eps*v) - loss(h)] / eps + logits_base = model.forward_from_layer(h_l_det, l) + loss_base = F.cross_entropy(logits_base, y, reduction='none') + logits_plus = model.forward_from_layer(h_l_det + eps * v, l) + loss_plus = F.cross_entropy(logits_plus, y, reduction='none') + g_j = (loss_plus - loss_base) / eps + + # Shuffled-target control: permute g_j within batch + if shuffle_targets: + perm = torch.randperm(batch, device=device) + g_j = g_j[perm] + + pred_j = (a_l * v).sum(dim=-1) + layer_proj_loss = layer_proj_loss + ((pred_j - g_j.detach()) ** 2).mean() + + loss_proj = loss_proj + layer_proj_loss / M + loss_proj = loss_proj / L + + vec_loss = loss_term + beta * loss_proj + vec_opt.zero_grad() + vec_loss.backward() + torch.nn.utils.clip_grad_norm_(vector_net.parameters(), 1.0) + vec_opt.step() + total_vloss += vec_loss.item() * batch + + # --- Block updates --- + with torch.no_grad(): + vec_credits = [] + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + a_l = vector_net(h_l_det, t_l, s) + vec_credits.append(a_l.detach()) + + dfa_credits = [(e_T @ Bs[l].T).detach() for l in range(L)] + + credits = [] + for l in range(L): + if credit_blend >= 1.0: + credits.append(vec_credits[l]) + elif credit_blend <= 0.0: + credits.append(dfa_credits[l]) + else: + vc_rms = (vec_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + credits.append(credit_blend * vec_credits[l] / vc_rms + + (1 - credit_blend) * dfa_credits[l] / dfa_rms) + + logits_out = model.out_head(hL_det) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad() + loss_out.backward() + head_opt.step() + + for l in range(L): + h_l = hiddens[l].detach() + a = credits[l] + rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + f_l = model.blocks[l](h_l) + local_loss = (f_l * (a / rms)).sum(dim=-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + + total_loss += loss_val.item() * batch + correct += (logits.argmax(1) == y).sum().item() + total += batch + + if epoch % 20 == 0 or epoch == 1: + acc = correct / total + print(f" [{tag}] Ep {epoch}: loss={total_loss/total:.4f}, acc={acc:.4f}, " + f"vloss={total_vloss/total:.6f}") + + return vector_net + + +def train_scalar_cb(model, teacher, device, args): + """Scalar credit bridge baseline.""" + d = model.d_hidden + L = model.num_blocks + num_classes = args.num_classes + + value_net = ValueNet(d_hidden=d, s_dim=num_classes, time_embed_dim=32, + hidden_dim=256, num_layers=3).to(device) + value_net_ema = create_ema_model(value_net) + + Bs = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)] + + block_opts = [optim.AdamW(b.parameters(), lr=args.lr, weight_decay=0.01) for b in model.blocks] + head_opt = optim.AdamW(model.out_head.parameters(), lr=args.lr, weight_decay=0.01) + value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb) + + warmup_epochs = max(1, int(args.epochs * args.warmup_ratio)) + + for epoch in range(1, args.epochs + 1): + model.train() + value_net.train() + + if epoch <= warmup_epochs: + credit_blend = 0.0 + else: + credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs)) + + total_loss, correct, total = 0, 0, 0 + for _ in range(args.steps_per_epoch): + x, y = generate_batch(teacher, d, num_classes, args.batch_size, device) + batch = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + true_loss = F.cross_entropy(logits, y, reduction='none').detach() + + hL_det = hiddens[-1].detach() + t_L = torch.ones(batch, device=device) + V_term = value_net(hL_det, t_L, s) + loss_term = ((V_term - true_loss) ** 2).mean() + + hL_req = hL_det.clone().requires_grad_(True) + V_at_L = value_net(hL_req, t_L, s) + grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0] + hL_req2 = hL_det.clone().requires_grad_(True) + logits_tgt = model.out_head(hL_req2) + ce = F.cross_entropy(logits_tgt, y, reduction='sum') + a_L_exact = torch.autograd.grad(ce, hL_req2, create_graph=False)[0].detach() + loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean() + + loss_bridge = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + t_next = torch.full((batch,), (l + 1) / L, device=device) + V_l = value_net(h_l_det, t_l, s) + with torch.no_grad(): + h_next = hiddens[l + 1].detach() + log_terms = [] + for k in range(args.K): + noise = args.sigma_bridge * torch.randn_like(h_next) + V_next = value_net_ema(h_next + noise, t_next, s) + log_terms.append(-V_next / args.lam) + log_stack = torch.stack(log_terms, dim=-1) + V_target = -args.lam * (torch.logsumexp(log_stack, dim=-1) - np.log(args.K)) + loss_bridge += ((V_l - V_target.detach()) ** 2).mean() + loss_bridge /= L + + vloss = loss_term + loss_bridge + args.term_grad_weight * loss_tgrad + value_opt.zero_grad() + vloss.backward() + torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) + value_opt.step() + update_ema(value_net, value_net_ema, args.ema_momentum) + + cb_credits = [] + for l in range(L): + h_l_det = hiddens[l].detach().requires_grad_(True) + t_l = torch.full((batch,), l / L, device=device) + V_l = value_net(h_l_det, t_l, s) + a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0] + cb_credits.append(a_l.detach()) + + dfa_credits = [(e_T @ Bs[l].T).detach() for l in range(L)] + + credits = [] + for l in range(L): + if credit_blend >= 1.0: + credits.append(cb_credits[l]) + elif credit_blend <= 0.0: + credits.append(dfa_credits[l]) + else: + cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + credits.append(credit_blend * cb_credits[l] / cb_rms + + (1 - credit_blend) * dfa_credits[l] / dfa_rms) + + logits_out = model.out_head(hL_det) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad() + loss_out.backward() + head_opt.step() + + for l in range(L): + h_l = hiddens[l].detach() + a = credits[l] + rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + f_l = model.blocks[l](h_l) + local_loss = (f_l * (a / rms)).sum(dim=-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + + total_loss += loss_val.item() * batch + correct += (logits.argmax(1) == y).sum().item() + total += batch + + if epoch % 20 == 0 or epoch == 1: + print(f" [scalar_cb] Ep {epoch}: loss={total_loss/total:.4f}, acc={correct/total:.4f}") + + return value_net + + +def train_dfa(model, teacher, device, args): + """DFA baseline.""" + d = model.d_hidden + L = model.num_blocks + num_classes = args.num_classes + Bs = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)] + + block_opts = [optim.AdamW(b.parameters(), lr=args.lr, weight_decay=0.01) for b in model.blocks] + head_opt = optim.AdamW(model.out_head.parameters(), lr=args.lr, weight_decay=0.01) + + for epoch in range(1, args.epochs + 1): + model.train() + total_loss, correct, total = 0, 0, 0 + for _ in range(args.steps_per_epoch): + x, y = generate_batch(teacher, d, num_classes, args.batch_size, device) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + hL_det = hiddens[-1].detach() + logits_out = model.out_head(hL_det) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad() + loss_out.backward() + head_opt.step() + for l in range(L): + h_l = hiddens[l].detach() + a = (e_T @ Bs[l].T).detach() + rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + f_l = model.blocks[l](h_l) + local_loss = (f_l * (a / rms)).sum(dim=-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + total_loss += loss_val.item() * batch + correct += (logits.argmax(1) == y).sum().item() + total += batch + if epoch % 20 == 0 or epoch == 1: + print(f" [DFA] Ep {epoch}: loss={total_loss/total:.4f}, acc={correct/total:.4f}") + return Bs + + +# ============================================================================= +# Diagnostics — uses INDEPENDENT eval directions (check A) +# ============================================================================= +def compute_diagnostics(model, teacher, device, method_name, args, + value_net=None, vector_net=None, dfa_Bs=None): + """ + Compute Gamma, rho, nudging per layer. + IMPORTANT: perturbation_correlation uses its own freshly-sampled directions, + completely independent of any training directions. This ensures check A. + """ + model.eval() + if value_net is not None: + value_net.eval() + if vector_net is not None: + vector_net.eval() + + d = model.d_hidden + L = model.num_blocks + num_classes = args.num_classes + + # Use a fixed eval seed different from training + eval_rng = torch.Generator(device=device) + eval_rng.manual_seed(99999) + + x = torch.randn(512, d, device=device, generator=eval_rng) + with torch.no_grad(): + teacher_logits = teacher(x) + y = teacher_logits.argmax(dim=-1) + batch = x.size(0) + + # BP gradients (evaluation only — never used for training) + h = x.detach().requires_grad_(True) + hiddens_bp = [h] + for block in model.blocks: + f = block(hiddens_bp[-1]) + h_next = hiddens_bp[-1] + f + hiddens_bp.append(h_next) + logits_bp = model.out_head(hiddens_bp[-1]) + loss_bp = F.cross_entropy(logits_bp, y) + grads = torch.autograd.grad(loss_bp, hiddens_bp, retain_graph=False) + bp_grads = {l: grads[l].detach().clone() for l in range(L + 1)} + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + + results = {'bp_cosine': [], 'perturbation_rho': [], 'nudging': []} + + for l in range(L): + h_l = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + + if method_name == 'dfa': + a_l = (s @ dfa_Bs[l].T).detach() + elif method_name == 'scalar_cb': + h_l_req = h_l.clone().requires_grad_(True) + V_l = value_net(h_l_req, t_l, s) + a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach() + elif method_name.startswith('vec'): + a_l = vector_net(h_l, t_l, s).detach() + else: + raise ValueError(f"Unknown: {method_name}") + + bp_cos = cosine_similarity_batch(a_l, bp_grads[l]) + results['bp_cosine'].append(float(bp_cos)) + + # perturbation_correlation uses its own random directions internally + # (from metrics/credit_metrics.py — independent of training directions) + def make_fwd_fn(start_l): + def fwd_fn(h): + with torch.no_grad(): + logits = model.forward_from_layer(h, start_l) + return F.cross_entropy(logits, y, reduction='none') + return fwd_fn + + fwd_fn = make_fwd_fn(l) + rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=32) + results['perturbation_rho'].append(float(rho)) + + nud = nudging_test(h_l, a_l, fwd_fn, eta=0.003) + results['nudging'].append(float(nud)) + + return results + + +# ============================================================================= +# Main +# ============================================================================= +def run_experiment(args): + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + os.makedirs(args.output_dir, exist_ok=True) + + all_results = [] + + for L in args.depths: + for seed in args.seeds: + print(f"\n{'='*60}") + print(f"L={L}, seed={seed}") + print(f"{'='*60}") + + teacher = TeacherNet(args.d_hidden, args.num_classes, L, + alpha=args.alpha, seed=seed * 1000).to(device) + + # --- DFA --- + print("\n --- DFA ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_dfa = StudentNet(args.d_hidden, args.num_classes, L, alpha=args.alpha).to(device) + Bs = train_dfa(model_dfa, teacher, device, args) + diag = compute_diagnostics(model_dfa, teacher, device, 'dfa', args, dfa_Bs=Bs) + r = {'method': 'dfa', 'L': L, 'seed': seed, + 'mean_gamma': float(np.mean(diag['bp_cosine'])), + 'mean_rho': float(np.mean(diag['perturbation_rho'])), + 'mean_nudge': float(np.mean(diag['nudging'])), + 'per_layer': diag} + print(f" Result: Gamma={r['mean_gamma']:.4f}, rho={r['mean_rho']:.4f}, nudge={r['mean_nudge']:.6f}") + all_results.append(r) + + # --- Scalar CB --- + print("\n --- Scalar CB ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_cb = StudentNet(args.d_hidden, args.num_classes, L, alpha=args.alpha).to(device) + vnet = train_scalar_cb(model_cb, teacher, device, args) + diag = compute_diagnostics(model_cb, teacher, device, 'scalar_cb', args, value_net=vnet) + r = {'method': 'scalar_cb', 'L': L, 'seed': seed, + 'mean_gamma': float(np.mean(diag['bp_cosine'])), + 'mean_rho': float(np.mean(diag['perturbation_rho'])), + 'mean_nudge': float(np.mean(diag['nudging'])), + 'per_layer': diag} + print(f" Result: Gamma={r['mean_gamma']:.4f}, rho={r['mean_rho']:.4f}, nudge={r['mean_nudge']:.6f}") + all_results.append(r) + + # --- Vector Field M4 (central diff, with terminal) --- + print("\n --- vec_eT_M4 (central, +term) ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_v4 = StudentNet(args.d_hidden, args.num_classes, L, alpha=args.alpha).to(device) + vnet4 = train_vector_field_audit(model_v4, teacher, device, args, M=4, + use_terminal=True, shuffle_targets=False, + use_central_diff=True, tag='vec_eT_M4') + diag = compute_diagnostics(model_v4, teacher, device, 'vec_eT_M4', args, vector_net=vnet4) + r = {'method': 'vec_eT_M4', 'L': L, 'seed': seed, + 'mean_gamma': float(np.mean(diag['bp_cosine'])), + 'mean_rho': float(np.mean(diag['perturbation_rho'])), + 'mean_nudge': float(np.mean(diag['nudging'])), + 'per_layer': diag} + print(f" Result: Gamma={r['mean_gamma']:.4f}, rho={r['mean_rho']:.4f}, nudge={r['mean_nudge']:.6f}") + all_results.append(r) + + # --- Vector Field M8 (central diff, with terminal) --- + if 8 in args.M_values: + print("\n --- vec_eT_M8 (central, +term) ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_v8 = StudentNet(args.d_hidden, args.num_classes, L, alpha=args.alpha).to(device) + vnet8 = train_vector_field_audit(model_v8, teacher, device, args, M=8, + use_terminal=True, shuffle_targets=False, + use_central_diff=True, tag='vec_eT_M8') + diag = compute_diagnostics(model_v8, teacher, device, 'vec_eT_M8', args, vector_net=vnet8) + r = {'method': 'vec_eT_M8', 'L': L, 'seed': seed, + 'mean_gamma': float(np.mean(diag['bp_cosine'])), + 'mean_rho': float(np.mean(diag['perturbation_rho'])), + 'mean_nudge': float(np.mean(diag['nudging'])), + 'per_layer': diag} + print(f" Result: Gamma={r['mean_gamma']:.4f}, rho={r['mean_rho']:.4f}, nudge={r['mean_nudge']:.6f}") + all_results.append(r) + + # ================================================================= + # SANITY CHECKS (only for first seed to save time, unless full mode) + # ================================================================= + if seed == args.seeds[0] or args.full_audit: + # --- Check B: Shuffled-target control --- + print("\n --- vec_eT_M4_shuffleCtrl ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_shuf = StudentNet(args.d_hidden, args.num_classes, L, alpha=args.alpha).to(device) + vnet_shuf = train_vector_field_audit(model_shuf, teacher, device, args, M=4, + use_terminal=True, shuffle_targets=True, + use_central_diff=True, tag='vec_shuffleCtrl') + diag = compute_diagnostics(model_shuf, teacher, device, 'vec_shuffleCtrl', args, vector_net=vnet_shuf) + r = {'method': 'vec_eT_M4_shuffleCtrl', 'L': L, 'seed': seed, + 'mean_gamma': float(np.mean(diag['bp_cosine'])), + 'mean_rho': float(np.mean(diag['perturbation_rho'])), + 'mean_nudge': float(np.mean(diag['nudging'])), + 'per_layer': diag} + print(f" Result: Gamma={r['mean_gamma']:.4f}, rho={r['mean_rho']:.4f}, nudge={r['mean_nudge']:.6f}") + all_results.append(r) + + # --- Check C: No-terminal ablation --- + print("\n --- vec_eT_M4_noTerm ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_nt = StudentNet(args.d_hidden, args.num_classes, L, alpha=args.alpha).to(device) + vnet_nt = train_vector_field_audit(model_nt, teacher, device, args, M=4, + use_terminal=False, shuffle_targets=False, + use_central_diff=True, tag='vec_noTerm') + diag = compute_diagnostics(model_nt, teacher, device, 'vec_noTerm', args, vector_net=vnet_nt) + r = {'method': 'vec_eT_M4_noTerm', 'L': L, 'seed': seed, + 'mean_gamma': float(np.mean(diag['bp_cosine'])), + 'mean_rho': float(np.mean(diag['perturbation_rho'])), + 'mean_nudge': float(np.mean(diag['nudging'])), + 'per_layer': diag} + print(f" Result: Gamma={r['mean_gamma']:.4f}, rho={r['mean_rho']:.4f}, nudge={r['mean_nudge']:.6f}") + all_results.append(r) + + # --- Check D: One-sided difference --- + print("\n --- vec_eT_M4_onesided ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_os = StudentNet(args.d_hidden, args.num_classes, L, alpha=args.alpha).to(device) + vnet_os = train_vector_field_audit(model_os, teacher, device, args, M=4, + use_terminal=True, shuffle_targets=False, + use_central_diff=False, tag='vec_onesided') + diag = compute_diagnostics(model_os, teacher, device, 'vec_onesided', args, vector_net=vnet_os) + r = {'method': 'vec_eT_M4_onesided', 'L': L, 'seed': seed, + 'mean_gamma': float(np.mean(diag['bp_cosine'])), + 'mean_rho': float(np.mean(diag['perturbation_rho'])), + 'mean_nudge': float(np.mean(diag['nudging'])), + 'per_layer': diag} + print(f" Result: Gamma={r['mean_gamma']:.4f}, rho={r['mean_rho']:.4f}, nudge={r['mean_nudge']:.6f}") + all_results.append(r) + + # ================================================================= + # Summary + # ================================================================= + print(f"\n{'='*80}") + print("AUDIT SUMMARY") + print(f"{'='*80}") + print(f"{'Method':<30} {'L':>3} {'seed':>5} {'Gamma':>8} {'rho':>8} {'nudge':>10}") + print("-" * 70) + for r in all_results: + print(f"{r['method']:<30} {r['L']:>3} {r['seed']:>5} " + f"{r['mean_gamma']:>8.4f} {r['mean_rho']:>8.4f} {r['mean_nudge']:>10.6f}") + + # Check verdicts + print(f"\n{'='*60}") + print("SANITY CHECK VERDICTS") + print(f"{'='*60}") + + for L in args.depths: + seed0 = args.seeds[0] + vec_main = [r for r in all_results if r['method'] == 'vec_eT_M4' and r['L'] == L and r['seed'] == seed0] + scalar_cb = [r for r in all_results if r['method'] == 'scalar_cb' and r['L'] == L and r['seed'] == seed0] + shuf = [r for r in all_results if r['method'] == 'vec_eT_M4_shuffleCtrl' and r['L'] == L and r['seed'] == seed0] + noterm = [r for r in all_results if r['method'] == 'vec_eT_M4_noTerm' and r['L'] == L and r['seed'] == seed0] + onesided = [r for r in all_results if r['method'] == 'vec_eT_M4_onesided' and r['L'] == L and r['seed'] == seed0] + + if not vec_main or not scalar_cb: + continue + v = vec_main[0] + cb = scalar_cb[0] + + print(f"\n L={L}:") + delta_gamma = v['mean_gamma'] - cb['mean_gamma'] + delta_rho = v['mean_rho'] - cb['mean_rho'] + print(f" vec_M4 vs scalar_cb: delta_Gamma={delta_gamma:+.4f}, delta_rho={delta_rho:+.4f}") + + if shuf: + s = shuf[0] + print(f" Check B (shuffle): Gamma={s['mean_gamma']:.4f}, rho={s['mean_rho']:.4f}") + if s['mean_gamma'] < v['mean_gamma'] * 0.5 and s['mean_rho'] < v['mean_rho'] * 0.5: + print(f" -> PASS: shuffled control collapses (Gamma dropped by {v['mean_gamma']-s['mean_gamma']:.3f})") + else: + print(f" -> FAIL: shuffled control too close to main result!") + + if noterm: + n = noterm[0] + print(f" Check C (noTerm): Gamma={n['mean_gamma']:.4f}, rho={n['mean_rho']:.4f}") + if n['mean_gamma'] < v['mean_gamma'] * 0.8: + print(f" -> PASS: terminal matching contributes (Gamma dropped by {v['mean_gamma']-n['mean_gamma']:.3f})") + else: + print(f" -> NOTE: terminal removal didn't collapse result. Perturbation target alone is sufficient.") + + if onesided: + o = onesided[0] + print(f" Check D (onesided): Gamma={o['mean_gamma']:.4f}, rho={o['mean_rho']:.4f}") + if abs(o['mean_gamma'] - v['mean_gamma']) < 0.15: + print(f" -> PASS: one-sided ≈ central (difference = {abs(o['mean_gamma']-v['mean_gamma']):.3f})") + else: + print(f" -> NOTE: one-sided differs from central by {abs(o['mean_gamma']-v['mean_gamma']):.3f}") + + # Final verdict + print(f"\n{'='*60}") + print("OVERALL AUDIT VERDICT") + print(f"{'='*60}") + all_pass = True + for L in args.depths: + for seed in args.seeds: + v = [r for r in all_results if r['method'] == 'vec_eT_M4' and r['L'] == L and r['seed'] == seed] + cb = [r for r in all_results if r['method'] == 'scalar_cb' and r['L'] == L and r['seed'] == seed] + if v and cb: + dg = v[0]['mean_gamma'] - cb[0]['mean_gamma'] + dr = v[0]['mean_rho'] - cb[0]['mean_rho'] + if dg < 0.2 or dr < 0.2: + print(f" L={L} seed={seed}: delta_Gamma={dg:.3f}, delta_rho={dr:.3f} - BELOW THRESHOLD") + all_pass = False + else: + print(f" L={L} seed={seed}: delta_Gamma={dg:.3f}, delta_rho={dr:.3f} - PASS") + + shuf_results = [r for r in all_results if 'shuffleCtrl' in r['method']] + for s in shuf_results: + if s['mean_rho'] > 0.3: + print(f" SHUFFLE CONTROL WARNING: L={s['L']} rho={s['mean_rho']:.3f} too high!") + all_pass = False + + if all_pass: + print("\n AUDIT PASSED. Vector field gains are real.") + else: + print("\n AUDIT FAILED or INCOMPLETE. Investigate before proceeding.") + + # Save + save_data = [] + for r in all_results: + save_r = {k: v for k, v in r.items() if k != 'per_layer'} + save_r['per_layer_gamma'] = r['per_layer']['bp_cosine'] + save_r['per_layer_rho'] = r['per_layer']['perturbation_rho'] + save_r['per_layer_nudge'] = r['per_layer']['nudging'] + save_data.append(save_r) + + out_path = os.path.join(args.output_dir, 'audit_results.json') + with open(out_path, 'w') as f: + json.dump(save_data, f, indent=2) + print(f"\nResults saved to {out_path}") + + +def main(): + parser = argparse.ArgumentParser(description='Phase 5A: Vector Credit Field Audit') + parser.add_argument('--d_hidden', type=int, default=128) + parser.add_argument('--num_classes', type=int, default=10) + parser.add_argument('--alpha', type=float, default=1.0) + parser.add_argument('--depths', type=int, nargs='+', default=[4]) + parser.add_argument('--M_values', type=int, nargs='+', default=[4, 8]) + parser.add_argument('--epochs', type=int, default=80) + parser.add_argument('--steps_per_epoch', type=int, default=50) + parser.add_argument('--batch_size', type=int, default=256) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--lr_fb', type=float, default=1e-3) + parser.add_argument('--warmup_ratio', type=float, default=0.05) + parser.add_argument('--term_grad_weight', type=float, default=1.0) + parser.add_argument('--lam', type=float, default=0.1) + parser.add_argument('--K', type=int, default=4) + parser.add_argument('--sigma_bridge', type=float, default=0.05) + parser.add_argument('--ema_momentum', type=float, default=0.995) + parser.add_argument('--pert_eps', type=float, default=1e-3) + parser.add_argument('--pert_beta', type=float, default=1.0) + parser.add_argument('--seeds', type=int, nargs='+', default=[42]) + parser.add_argument('--gpu', type=int, default=2) + parser.add_argument('--output_dir', type=str, default='results/vector_audit') + parser.add_argument('--full_audit', action='store_true', + help='Run sanity checks for all seeds (default: first seed only)') + args = parser.parse_args() + run_experiment(args) + + +if __name__ == '__main__': + main() -- cgit v1.2.3