""" Confirmatory Paper Experiments — single-script entry point. Four sub-experiments: A1: Synthetic Nonlinearity Ladder (10 seeds x {alpha} x {depth}) A2: CIFAR State-vs-Credit Counterexample (10 seeds) A3: Frozen vs Online Dissociation (10 seeds) A4: Protocol Dependence Panel (data assembly from existing results) Usage: CUDA_VISIBLE_DEVICES=3 python experiments/confirmatory_paper_experiments.py \ --experiment {A1,A2,A3,A4,all} --gpu 3 --output_dir results/confirmatory Set PYTHONUNBUFFERED=1 for nohup-safe logging. """ import os import sys import json import argparse import time import copy import csv import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader, TensorDataset import torchvision import torchvision.transforms as transforms sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.residual_mlp import ResidualMLP from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema from models.state_bridge import StateBridgeNet from metrics.credit_metrics import ( cosine_similarity_batch, perturbation_correlation, nudging_test ) # ============================================================================= # Shared helpers # ============================================================================= def set_seed(seed): torch.manual_seed(seed) np.random.seed(seed) torch.cuda.manual_seed_all(seed) def serialize(obj): if isinstance(obj, dict): return {str(k): serialize(v) for k, v in obj.items()} elif isinstance(obj, list): return [serialize(v) for v in obj] elif isinstance(obj, (np.floating, np.integer)): return float(obj) elif isinstance(obj, np.ndarray): return obj.tolist() elif isinstance(obj, torch.Tensor): return obj.cpu().numpy().tolist() return obj def get_cifar10(batch_size=128): transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) trainset = torchvision.datasets.CIFAR10( root='./data', train=True, download=True, transform=transform_train) testset = torchvision.datasets.CIFAR10( root='./data', train=False, download=True, transform=transform_test) train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) return train_loader, test_loader def evaluate_cifar(model, test_loader, device): model.eval() correct, total = 0, 0 with torch.no_grad(): for x, y in test_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) logits = model(x) correct += (logits.argmax(1) == y).sum().item() total += x.size(0) return correct / total def evaluate_synth(model, test_loader, device): model.eval() correct, total = 0, 0 with torch.no_grad(): for x, y in test_loader: x, y = x.to(device), y.to(device) logits = model(x) correct += (logits.argmax(1) == y).sum().item() total += x.size(0) return correct / total def compute_diagnostics_generic(model, test_loader, device, num_classes, method_name, value_net=None, state_pred=None, dfa_Bs=None, flat_input=True): """ Compute Gamma (offline BP cosine), rho (perturbation correlation), and nudge. Returns mean over layers. flat_input: if True, x is flattened before forward (CIFAR); else passed as-is (synth). """ model.eval() if value_net is not None: value_net.eval() if state_pred is not None: state_pred.eval() L = model.num_blocks for x, y in test_loader: if flat_input: x = x.view(x.size(0), -1).to(device) else: x = x.to(device) y = y.to(device) break batch = x.size(0) # BP gradients via manual graph with torch.no_grad(): if flat_input: h0 = model.embed(x.detach()) else: h0 = x.detach() h_start = h0.clone().requires_grad_(True) hiddens_req = [h_start] for block in model.blocks: f = block(hiddens_req[-1]) hiddens_req.append(hiddens_req[-1] + f) if flat_input: logits_bp = model.out_head(model.out_ln(hiddens_req[-1])) else: logits_bp = model.out_head(hiddens_req[-1]) loss_bp = F.cross_entropy(logits_bp, y) grads = torch.autograd.grad(loss_bp, hiddens_req, retain_graph=False) bp_grads = {l: grads[l].detach().clone() for l in range(len(hiddens_req))} # Clean forward with torch.no_grad(): if flat_input: logits, hiddens = model(x, return_hidden=True) else: logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 s = e_T.detach() gamma_list, rho_list, nudge_list = [], [], [] for l in range(L): h_l = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) if method_name == 'bp': a_l = bp_grads[l] elif method_name == 'dfa': a_l = (e_T @ dfa_Bs[l].T).detach() elif method_name == 'state_bridge': h_l_req = h_l.clone().requires_grad_(True) pred_hL = state_pred(h_l_req, t_l, s) if flat_input: pred_logits = model.out_head(model.out_ln(pred_hL)) else: pred_logits = model.out_head(pred_hL) pred_loss = F.cross_entropy(pred_logits, y, reduction='sum') a_l = torch.autograd.grad(pred_loss, h_l_req, create_graph=False)[0].detach() elif method_name == 'credit_bridge': h_l_req = h_l.clone().requires_grad_(True) V_l = value_net(h_l_req, t_l, s) a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach() else: raise ValueError(f"Unknown method: {method_name}") gamma = cosine_similarity_batch(a_l, bp_grads[l]) gamma_list.append(gamma) def make_fwd_fn(start_l): def fwd_fn(h): with torch.no_grad(): curr = h for i in range(start_l, L): curr = curr + model.blocks[i](curr) if flat_input: out = model.out_head(model.out_ln(curr)) else: out = model.out_head(curr) return F.cross_entropy(out, y, reduction='none') return fwd_fn fwd_fn = make_fwd_fn(l) rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16) rho_list.append(rho) nudge = nudging_test(h_l, a_l, fwd_fn, eta=0.01) nudge_list.append(nudge) return { 'Gamma': float(np.mean(gamma_list)), 'rho': float(np.mean(rho_list)), 'nudge': float(np.mean(nudge_list)), 'per_layer_gamma': gamma_list, 'per_layer_rho': rho_list, 'per_layer_nudge': nudge_list, } # ============================================================================= # Shared training methods (CIFAR-style: flat input, out_ln present) # ============================================================================= def _train_bp_cifar(model, train_loader, test_loader, device, epochs, lr, wd): optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) log = {'train_loss': [], 'test_acc': []} for epoch in range(1, epochs + 1): model.train() total_loss, correct, total = 0, 0, 0 for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) logits = model(x) loss = F.cross_entropy(logits, y) optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.item() * x.size(0) correct += (logits.argmax(1) == y).sum().item() total += x.size(0) scheduler.step() log['train_loss'].append(total_loss / total) log['test_acc'].append(evaluate_cifar(model, test_loader, device)) if epoch % 10 == 0 or epoch == 1: print(f" [BP] Ep {epoch}: loss={log['train_loss'][-1]:.4f} " f"test={log['test_acc'][-1]:.4f}", flush=True) return log def _train_dfa_cifar(model, train_loader, test_loader, device, epochs, lr, wd): d = model.d_hidden L = model.num_blocks C = 10 Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks] embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr, weight_decay=wd) all_sch = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]) log = {'train_loss': [], 'test_acc': []} for epoch in range(1, epochs + 1): model.train() total_loss, correct, total = 0, 0, 0 for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) loss_val = F.cross_entropy(logits, y) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 hL_det = hiddens[-1].detach() logits_out = model.out_head(model.out_ln(hL_det)) loss_out = F.cross_entropy(logits_out, y) head_opt.zero_grad() loss_out.backward() head_opt.step() for l in range(L): h_l = hiddens[l].detach() a_dfa = (e_T @ Bs[l].T).detach() rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 a_norm = a_dfa / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a_norm).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) block_opts[l].step() a_0 = (e_T @ Bs[0].T).detach() rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 h0 = model.embed(x) embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean() embed_opt.zero_grad() embed_loss.backward() embed_opt.step() total_loss += loss_val.item() * batch correct += (logits.argmax(1) == y).sum().item() total += batch for s in all_sch: s.step() log['train_loss'].append(total_loss / total) log['test_acc'].append(evaluate_cifar(model, test_loader, device)) if epoch % 10 == 0 or epoch == 1: print(f" [DFA] Ep {epoch}: loss={log['train_loss'][-1]:.4f} " f"test={log['test_acc'][-1]:.4f}", flush=True) return log, Bs def _train_state_bridge_cifar(model, train_loader, test_loader, device, epochs, lr, lr_fb, wd): d = model.d_hidden L = model.num_blocks C = 10 state_pred = StateBridgeNet(d_hidden=d, s_dim=C, time_embed_dim=32, hidden_dim=256, num_layers=3).to(device) block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks] embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr, weight_decay=wd) state_opt = optim.Adam(state_pred.parameters(), lr=lr_fb) all_sch = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]) log = {'train_loss': [], 'test_acc': [], 'state_pred_error': []} for epoch in range(1, epochs + 1): model.train() state_pred.train() total_loss, correct, total, total_se = 0, 0, 0, 0 for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) loss_val = F.cross_entropy(logits, y) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 s = e_T.detach() hL_det = hiddens[-1].detach() # Train state predictor state_loss = 0.0 for l in range(L): h_l_det = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) pred_hL = state_pred(h_l_det, t_l, s) target_norm = hL_det.norm(dim=-1, keepdim=True).clamp(min=1.0) state_loss = state_loss + (((pred_hL - hL_det) / target_norm) ** 2).sum(dim=-1).mean() state_loss = state_loss / L state_opt.zero_grad() state_loss.backward() state_opt.step() total_se += state_loss.item() * batch # Compute credits credits = [] for l in range(L): h_l_det = hiddens[l].detach().requires_grad_(True) t_l = torch.full((batch,), l / L, device=device) pred_hL = state_pred(h_l_det, t_l, s) pred_logits = model.out_head(model.out_ln(pred_hL)) pred_loss = F.cross_entropy(pred_logits, y, reduction='sum') a_l = torch.autograd.grad(pred_loss, h_l_det, create_graph=False)[0] credits.append(a_l.detach()) # Update head logits_out = model.out_head(model.out_ln(hL_det)) loss_out = F.cross_entropy(logits_out, y) head_opt.zero_grad() loss_out.backward() head_opt.step() # Update blocks for l in range(L): h_l = hiddens[l].detach() a = credits[l] rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 a_norm = a / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a_norm).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) block_opts[l].step() # Update embedding a_0 = credits[0] rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 h0 = model.embed(x) embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean() embed_opt.zero_grad() embed_loss.backward() embed_opt.step() total_loss += loss_val.item() * batch correct += (logits.argmax(1) == y).sum().item() total += batch for sch in all_sch: sch.step() log['train_loss'].append(total_loss / total) log['test_acc'].append(evaluate_cifar(model, test_loader, device)) log['state_pred_error'].append(total_se / total) if epoch % 10 == 0 or epoch == 1: print(f" [SB] Ep {epoch}: loss={log['train_loss'][-1]:.4f} " f"test={log['test_acc'][-1]:.4f} se={log['state_pred_error'][-1]:.4f}", flush=True) return log, state_pred def _train_credit_bridge_cifar(model, train_loader, test_loader, device, epochs, lr, lr_fb, wd, warmup_ratio=0.2, term_grad_weight=1.0, lam=0.1, K=4, sigma_bridge=0.05, ema_momentum=0.995): d = model.d_hidden L = model.num_blocks C = 10 warmup_epochs = max(1, int(epochs * warmup_ratio)) value_net = ValueNet(d_hidden=d, s_dim=C, time_embed_dim=32, hidden_dim=256, num_layers=3).to(device) value_net_ema = create_ema_model(value_net) Bs_fallback = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks] embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr, weight_decay=wd) value_opt = optim.Adam(value_net.parameters(), lr=lr_fb) all_sch = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]) log = {'train_loss': [], 'test_acc': [], 'value_loss': []} for epoch in range(1, epochs + 1): model.train() value_net.train() total_loss, correct, total, total_vloss = 0, 0, 0, 0 if epoch <= warmup_epochs: credit_blend = 0.0 else: credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs)) for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) loss_val = F.cross_entropy(logits, y) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 s = e_T.detach() true_loss = F.cross_entropy(logits, y, reduction='none').detach() hL_det = hiddens[-1].detach() # Train value net t_L = torch.ones(batch, device=device) V_terminal = value_net(hL_det, t_L, s) loss_term = ((V_terminal - true_loss) ** 2).mean() loss_tgrad = torch.tensor(0.0, device=device) if term_grad_weight > 0: hL_req = hL_det.clone().requires_grad_(True) V_at_L = value_net(hL_req, t_L, s) grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0] hL_req2 = hL_det.clone().requires_grad_(True) logits_tgt = model.out_head(model.out_ln(hL_req2)) ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum') a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach() loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean() loss_bridge = 0.0 for l in range(L): h_l_det = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) t_l_next = torch.full((batch,), (l + 1) / L, device=device) V_l = value_net(h_l_det, t_l, s) with torch.no_grad(): h_next_det = hiddens[l + 1].detach() log_terms = [] for k in range(K): noise = sigma_bridge * torch.randn_like(h_next_det) V_next = value_net_ema(h_next_det + noise, t_l_next, s) log_terms.append(-V_next / lam) log_stack = torch.stack(log_terms, dim=-1) V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K)) loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() loss_bridge = loss_bridge / L value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad value_opt.zero_grad() value_loss.backward() torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) value_opt.step() update_ema(value_net, value_net_ema, ema_momentum) total_vloss += value_loss.item() * batch # Credits cb_credits = [] for l in range(L): h_l_det = hiddens[l].detach().requires_grad_(True) t_l = torch.full((batch,), l / L, device=device) V_l = value_net(h_l_det, t_l, s) a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0] cb_credits.append(a_l.detach()) dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)] credits = [] for l in range(L): if credit_blend >= 1.0: a = cb_credits[l] elif credit_blend <= 0.0: a = dfa_credits[l] else: cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 a = credit_blend * (cb_credits[l] / cb_rms) + \ (1 - credit_blend) * (dfa_credits[l] / dfa_rms) credits.append(a) # Update head logits_out = model.out_head(model.out_ln(hL_det)) loss_out = F.cross_entropy(logits_out, y) head_opt.zero_grad() loss_out.backward() head_opt.step() # Update blocks for l in range(L): h_l = hiddens[l].detach() a = credits[l] rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 a_norm = a / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a_norm).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) block_opts[l].step() # Update embedding a_0 = credits[0] rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 h0 = model.embed(x) embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean() embed_opt.zero_grad() embed_loss.backward() embed_opt.step() total_loss += loss_val.item() * batch correct += (logits.argmax(1) == y).sum().item() total += batch for sch in all_sch: sch.step() log['train_loss'].append(total_loss / total) log['test_acc'].append(evaluate_cifar(model, test_loader, device)) log['value_loss'].append(total_vloss / total) if epoch % 10 == 0 or epoch == 1: phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}" print(f" [CB] Ep {epoch} ({phase}): loss={log['train_loss'][-1]:.4f} " f"test={log['test_acc'][-1]:.4f}", flush=True) return log, value_net # ============================================================================= # A1: Synthetic Nonlinearity Ladder # ============================================================================= class TeacherNet: """Fixed teacher network with controllable nonlinearity.""" def __init__(self, d_hidden, num_blocks, num_classes, alpha, seed=0): rng = np.random.RandomState(seed) self.d_hidden = d_hidden self.num_blocks = num_blocks self.num_classes = num_classes self.alpha = alpha self.Ws = [] for l in range(num_blocks): W = rng.randn(d_hidden, d_hidden).astype(np.float32) W = W / (np.linalg.norm(W, ord=2) + 1e-8) * 0.3 self.Ws.append(torch.from_numpy(W)) U = rng.randn(num_classes, d_hidden).astype(np.float32) U = U / (np.linalg.norm(U, ord=2) + 1e-8) self.U = torch.from_numpy(U) def to(self, device): self.Ws = [W.to(device) for W in self.Ws] self.U = self.U.to(device) return self def phi(self, z): return (1 - self.alpha) * z + self.alpha * torch.tanh(z) def forward(self, h0): h = h0 hiddens = [h] for l in range(self.num_blocks): f = F.linear(self.phi(h), self.Ws[l]) h = h + f hiddens.append(h) logits = F.linear(h, self.U) return logits, hiddens class StudentBlock(nn.Module): def __init__(self, d_hidden, alpha): super().__init__() self.ln = nn.LayerNorm(d_hidden) self.w = nn.Linear(d_hidden, d_hidden, bias=False) self.alpha = alpha nn.init.normal_(self.w.weight, std=0.01) def phi(self, z): return (1 - self.alpha) * z + self.alpha * torch.tanh(z) def forward(self, h): return self.w(self.phi(self.ln(h))) class StudentNet(nn.Module): def __init__(self, d_hidden, num_classes, num_blocks, alpha): super().__init__() self.blocks = nn.ModuleList([StudentBlock(d_hidden, alpha) for _ in range(num_blocks)]) self.out_head = nn.Linear(d_hidden, num_classes) self.num_blocks = num_blocks self.d_hidden = d_hidden def forward(self, x, return_hidden=False): h = x hiddens = [h] if return_hidden else None for block in self.blocks: f = block(h) h = h + f if return_hidden: hiddens.append(h) logits = self.out_head(h) if return_hidden: return logits, hiddens return logits def forward_from_layer(self, h, start_layer): for i in range(start_layer, self.num_blocks): f = self.blocks[i](h) h = h + f return self.out_head(h) def generate_synth_dataset(teacher, num_samples, d_hidden, device, seed=0): torch.manual_seed(seed) X = torch.randn(num_samples, d_hidden, device=device) with torch.no_grad(): logits, _ = teacher.forward(X) Y = logits.argmax(dim=-1) return X, Y def _train_bp_synth(model, train_loader, test_loader, device, epochs, lr, wd): optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) log = {'test_acc': []} for epoch in range(1, epochs + 1): model.train() for x, y in train_loader: x, y = x.to(device), y.to(device) logits = model(x) loss = F.cross_entropy(logits, y) optimizer.zero_grad() loss.backward() optimizer.step() scheduler.step() log['test_acc'].append(evaluate_synth(model, test_loader, device)) if epoch % 20 == 0 or epoch == 1: print(f" [BP] Ep {epoch}: test={log['test_acc'][-1]:.4f}", flush=True) return log def _train_dfa_synth(model, train_loader, test_loader, device, epochs, lr, wd, C): d = model.d_hidden L = model.num_blocks Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks] head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd) all_sch = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]) log = {'test_acc': []} for epoch in range(1, epochs + 1): model.train() for x, y in train_loader: x, y = x.to(device), y.to(device) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 hL_det = hiddens[-1].detach() logits_out = model.out_head(hL_det) loss_out = F.cross_entropy(logits_out, y) head_opt.zero_grad() loss_out.backward() head_opt.step() for l in range(L): h_l = hiddens[l].detach() a_dfa = (e_T @ Bs[l].T).detach() rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 a_norm = a_dfa / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a_norm).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) block_opts[l].step() for s in all_sch: s.step() log['test_acc'].append(evaluate_synth(model, test_loader, device)) if epoch % 20 == 0 or epoch == 1: print(f" [DFA] Ep {epoch}: test={log['test_acc'][-1]:.4f}", flush=True) return log, Bs def _train_state_bridge_synth(model, train_loader, test_loader, device, epochs, lr, lr_fb, wd, C): d = model.d_hidden L = model.num_blocks state_pred = StateBridgeNet(d_hidden=d, s_dim=C, time_embed_dim=32, hidden_dim=256, num_layers=3).to(device) block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks] head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd) state_opt = optim.Adam(state_pred.parameters(), lr=lr_fb) all_sch = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]) log = {'test_acc': [], 'state_pred_error': []} for epoch in range(1, epochs + 1): model.train() state_pred.train() total_se, n = 0.0, 0 for x, y in train_loader: x, y = x.to(device), y.to(device) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 s = e_T.detach() hL_det = hiddens[-1].detach() # Train state predictor state_loss = 0.0 for l in range(L): h_l_det = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) pred_hL = state_pred(h_l_det, t_l, s) target_norm = hL_det.norm(dim=-1, keepdim=True).clamp(min=1.0) state_loss = state_loss + (((pred_hL - hL_det) / target_norm) ** 2).sum(dim=-1).mean() state_loss = state_loss / L state_opt.zero_grad() state_loss.backward() state_opt.step() total_se += state_loss.item() * batch n += batch # Credits credits = [] for l in range(L): h_l_det = hiddens[l].detach().requires_grad_(True) t_l = torch.full((batch,), l / L, device=device) pred_hL = state_pred(h_l_det, t_l, s) pred_logits = model.out_head(pred_hL) pred_loss = F.cross_entropy(pred_logits, y, reduction='sum') a_l = torch.autograd.grad(pred_loss, h_l_det, create_graph=False)[0] credits.append(a_l.detach()) # Update head logits_out = model.out_head(hL_det) loss_out = F.cross_entropy(logits_out, y) head_opt.zero_grad() loss_out.backward() head_opt.step() # Update blocks for l in range(L): h_l = hiddens[l].detach() a = credits[l] rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 a_norm = a / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a_norm).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) block_opts[l].step() for sch in all_sch: sch.step() log['test_acc'].append(evaluate_synth(model, test_loader, device)) log['state_pred_error'].append(total_se / n) if epoch % 20 == 0 or epoch == 1: print(f" [SB] Ep {epoch}: test={log['test_acc'][-1]:.4f} " f"se={log['state_pred_error'][-1]:.4f}", flush=True) return log, state_pred def _train_credit_bridge_synth(model, train_loader, test_loader, device, epochs, lr, lr_fb, wd, C, warmup_ratio=0.2, term_grad_weight=1.0, lam=0.1, K=4, sigma_bridge=0.05, ema_momentum=0.995): d = model.d_hidden L = model.num_blocks warmup_epochs = max(1, int(epochs * warmup_ratio)) value_net = ValueNet(d_hidden=d, s_dim=C, time_embed_dim=32, hidden_dim=256, num_layers=3).to(device) value_net_ema = create_ema_model(value_net) Bs_fallback = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks] head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd) value_opt = optim.Adam(value_net.parameters(), lr=lr_fb) all_sch = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]) log = {'test_acc': []} for epoch in range(1, epochs + 1): model.train() value_net.train() if epoch <= warmup_epochs: credit_blend = 0.0 else: credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs)) for x, y in train_loader: x, y = x.to(device), y.to(device) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 s = e_T.detach() true_loss = F.cross_entropy(logits, y, reduction='none').detach() hL_det = hiddens[-1].detach() # Value net training t_L = torch.ones(batch, device=device) V_terminal = value_net(hL_det, t_L, s) loss_term = ((V_terminal - true_loss) ** 2).mean() loss_tgrad = torch.tensor(0.0, device=device) if term_grad_weight > 0: hL_req = hL_det.clone().requires_grad_(True) V_at_L = value_net(hL_req, t_L, s) grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0] hL_req2 = hL_det.clone().requires_grad_(True) logits_tgt = model.out_head(hL_req2) ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum') a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach() loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean() loss_bridge = 0.0 for l in range(L): h_l_det = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) t_l_next = torch.full((batch,), (l + 1) / L, device=device) V_l = value_net(h_l_det, t_l, s) with torch.no_grad(): h_next_det = hiddens[l + 1].detach() log_terms = [] for k in range(K): noise = sigma_bridge * torch.randn_like(h_next_det) V_next = value_net_ema(h_next_det + noise, t_l_next, s) log_terms.append(-V_next / lam) log_stack = torch.stack(log_terms, dim=-1) V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K)) loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() loss_bridge = loss_bridge / L value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad value_opt.zero_grad() value_loss.backward() torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) value_opt.step() update_ema(value_net, value_net_ema, ema_momentum) # Credits cb_credits = [] for l in range(L): h_l_det = hiddens[l].detach().requires_grad_(True) t_l = torch.full((batch,), l / L, device=device) V_l = value_net(h_l_det, t_l, s) a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0] cb_credits.append(a_l.detach()) dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)] credits = [] for l in range(L): if credit_blend >= 1.0: a = cb_credits[l] elif credit_blend <= 0.0: a = dfa_credits[l] else: cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 a = (credit_blend * (cb_credits[l] / cb_rms) + (1 - credit_blend) * (dfa_credits[l] / dfa_rms)) credits.append(a) # Update head logits_out = model.out_head(hL_det) loss_out = F.cross_entropy(logits_out, y) head_opt.zero_grad() loss_out.backward() head_opt.step() # Update blocks for l in range(L): h_l = hiddens[l].detach() a = credits[l] rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 a_norm = a / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a_norm).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) block_opts[l].step() for sch in all_sch: sch.step() log['test_acc'].append(evaluate_synth(model, test_loader, device)) if epoch % 20 == 0 or epoch == 1: print(f" [CB] Ep {epoch}: test={log['test_acc'][-1]:.4f}", flush=True) return log, value_net def _compute_synth_state_err(model, state_pred, test_loader, device, C): """Compute mean per-layer state prediction error on synth test set.""" model.eval() state_pred.eval() L = model.num_blocks total_se, n = 0.0, 0 with torch.no_grad(): for x, y in test_loader: x, y = x.to(device), y.to(device) batch = x.size(0) logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 s = e_T.detach() hL_det = hiddens[-1].detach() se = 0.0 for l in range(L): h_l_det = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) pred_hL = state_pred(h_l_det, t_l, s) target_norm = hL_det.norm(dim=-1, keepdim=True).clamp(min=1.0) se += (((pred_hL - hL_det) / target_norm) ** 2).sum(dim=-1).mean().item() total_se += (se / L) * batch n += batch return total_se / n def _compute_synth_diagnostics(model, test_loader, device, method_name, value_net=None, state_pred=None, dfa_Bs=None, C=10): """Compute Gamma, rho for synth model (no flat input, no out_ln).""" model.eval() if value_net is not None: value_net.eval() if state_pred is not None: state_pred.eval() L = model.num_blocks for x, y in test_loader: x, y = x.to(device), y.to(device) break batch = x.size(0) # BP gradients h_list = [x.detach().requires_grad_(True)] for block in model.blocks: f = block(h_list[-1]) h_list.append(h_list[-1] + f) logits_bp = model.out_head(h_list[-1]) loss_bp = F.cross_entropy(logits_bp, y) grads = torch.autograd.grad(loss_bp, h_list, retain_graph=False) bp_grads = {l: grads[l].detach().clone() for l in range(len(h_list))} with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 s = e_T.detach() gamma_list, rho_list = [], [] for l in range(L): h_l = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) if method_name == 'bp': a_l = bp_grads[l] elif method_name == 'dfa': a_l = (e_T @ dfa_Bs[l].T).detach() elif method_name == 'state_bridge': h_l_req = h_l.clone().requires_grad_(True) pred_hL = state_pred(h_l_req, t_l, s) pred_logits = model.out_head(pred_hL) pred_loss = F.cross_entropy(pred_logits, y, reduction='sum') a_l = torch.autograd.grad(pred_loss, h_l_req, create_graph=False)[0].detach() elif method_name == 'credit_bridge': h_l_req = h_l.clone().requires_grad_(True) V_l = value_net(h_l_req, t_l, s) a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach() else: raise ValueError(f"Unknown method: {method_name}") gamma = cosine_similarity_batch(a_l, bp_grads[l]) gamma_list.append(gamma) def make_fwd_fn(start_l): def fwd_fn(h): with torch.no_grad(): curr = h for i in range(start_l, L): curr = curr + model.blocks[i](curr) out = model.out_head(curr) return F.cross_entropy(out, y, reduction='none') return fwd_fn fwd_fn = make_fwd_fn(l) rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16) rho_list.append(rho) return { 'Gamma': float(np.mean(gamma_list)), 'rho': float(np.mean(rho_list)), } def run_A1(args, device): """A1: Synthetic Nonlinearity Ladder — 10 seeds.""" print("\n" + "=" * 70) print("A1: Synthetic Nonlinearity Ladder") print("=" * 70, flush=True) alphas = [0.0, 0.5, 1.0] depths = [4, 8] seeds = [42, 123, 456, 789, 1024, 2048, 3000, 4000, 5000, 6000] d = 128 C = 10 epochs = 80 steps_per_epoch = 50 batch_size = 256 n_train = steps_per_epoch * batch_size n_test = 2000 lr = 1e-3 lr_fb = 1e-3 wd = 0.01 os.makedirs(args.output_dir, exist_ok=True) csv_path = os.path.join(args.output_dir, 'A1_synth_ladder.csv') rows = [] total_configs = len(alphas) * len(depths) * len(seeds) done = 0 for alpha in alphas: for L in depths: for seed in seeds: done += 1 print(f"\n[A1] alpha={alpha}, L={L}, seed={seed} ({done}/{total_configs})", flush=True) set_seed(seed) teacher = TeacherNet(d, L, C, alpha, seed=0).to(device) X_train, Y_train = generate_synth_dataset(teacher, n_train, d, device, seed=seed) X_test, Y_test = generate_synth_dataset(teacher, n_test, d, device, seed=seed + 10000) train_ds = TensorDataset(X_train, Y_train) test_ds = TensorDataset(X_test, Y_test) train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False) # BP print(" [BP]", flush=True) set_seed(seed) model_bp = StudentNet(d, C, L, alpha).to(device) bp_log = _train_bp_synth(model_bp, train_loader, test_loader, device, epochs, lr, wd) bp_diag = _compute_synth_diagnostics(model_bp, test_loader, device, 'bp', C=C) rows.append({ 'alpha': alpha, 'depth': L, 'method': 'bp', 'seed': seed, 'StateErr': float('nan'), 'Gamma': bp_diag['Gamma'], 'rho': bp_diag['rho'], 'acc': bp_log['test_acc'][-1], }) # DFA print(" [DFA]", flush=True) set_seed(seed) model_dfa = StudentNet(d, C, L, alpha).to(device) dfa_log, dfa_Bs = _train_dfa_synth(model_dfa, train_loader, test_loader, device, epochs, lr, wd, C) dfa_diag = _compute_synth_diagnostics(model_dfa, test_loader, device, 'dfa', dfa_Bs=dfa_Bs, C=C) rows.append({ 'alpha': alpha, 'depth': L, 'method': 'dfa', 'seed': seed, 'StateErr': float('nan'), 'Gamma': dfa_diag['Gamma'], 'rho': dfa_diag['rho'], 'acc': dfa_log['test_acc'][-1], }) # State Bridge print(" [SB]", flush=True) set_seed(seed) model_sb = StudentNet(d, C, L, alpha).to(device) sb_log, state_pred = _train_state_bridge_synth(model_sb, train_loader, test_loader, device, epochs, lr, lr_fb, wd, C) sb_diag = _compute_synth_diagnostics(model_sb, test_loader, device, 'state_bridge', state_pred=state_pred, C=C) state_err = _compute_synth_state_err(model_sb, state_pred, test_loader, device, C) rows.append({ 'alpha': alpha, 'depth': L, 'method': 'state_bridge', 'seed': seed, 'StateErr': state_err, 'Gamma': sb_diag['Gamma'], 'rho': sb_diag['rho'], 'acc': sb_log['test_acc'][-1], }) # Credit Bridge (Scalar eT) print(" [CB]", flush=True) set_seed(seed) model_cb = StudentNet(d, C, L, alpha).to(device) cb_log, vnet = _train_credit_bridge_synth(model_cb, train_loader, test_loader, device, epochs, lr, lr_fb, wd, C) cb_diag = _compute_synth_diagnostics(model_cb, test_loader, device, 'credit_bridge', value_net=vnet, C=C) rows.append({ 'alpha': alpha, 'depth': L, 'method': 'credit_bridge', 'seed': seed, 'StateErr': float('nan'), 'Gamma': cb_diag['Gamma'], 'rho': cb_diag['rho'], 'acc': cb_log['test_acc'][-1], }) print(f" Summary: BP={bp_log['test_acc'][-1]:.4f} " f"DFA={dfa_log['test_acc'][-1]:.4f} " f"SB={sb_log['test_acc'][-1]:.4f}(se={state_err:.4f}) " f"CB={cb_log['test_acc'][-1]:.4f}", flush=True) # Save CSV fieldnames = ['alpha', 'depth', 'method', 'seed', 'StateErr', 'Gamma', 'rho', 'acc'] with open(csv_path, 'w', newline='') as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() writer.writerows(rows) print(f"\n[A1] Saved {len(rows)} rows to {csv_path}", flush=True) # Also save JSON for debugging json_path = csv_path.replace('.csv', '.json') with open(json_path, 'w') as f: json.dump(serialize(rows), f, indent=2) return rows # ============================================================================= # A2: CIFAR State-vs-Credit Counterexample # ============================================================================= def run_A2(args, device): """A2: CIFAR State-vs-Credit Counterexample — 10 seeds.""" print("\n" + "=" * 70) print("A2: CIFAR State-vs-Credit Counterexample") print("=" * 70, flush=True) seeds = [42, 123, 456, 789, 1024, 2048, 3000, 4000, 5000, 6000] L = 4 d = 256 epochs = 100 lr = 1e-3 lr_fb = 1e-3 wd = 0.01 input_dim = 32 * 32 * 3 C = 10 os.makedirs(args.output_dir, exist_ok=True) csv_path = os.path.join(args.output_dir, 'A2_cifar_state_vs_credit.csv') rows = [] train_loader, test_loader = get_cifar10(batch_size=128) for i, seed in enumerate(seeds): print(f"\n[A2] Seed {seed} ({i+1}/{len(seeds)})", flush=True) # DFA print(" [DFA]", flush=True) set_seed(seed) model_dfa = ResidualMLP(input_dim, d, C, L).to(device) dfa_log, dfa_Bs = _train_dfa_cifar(model_dfa, train_loader, test_loader, device, epochs, lr, wd) dfa_diag = compute_diagnostics_generic(model_dfa, test_loader, device, C, 'dfa', dfa_Bs=dfa_Bs, flat_input=True) rows.append({ 'method': 'dfa', 'seed': seed, 'StateErr': float('nan'), 'Gamma': dfa_diag['Gamma'], 'rho': dfa_diag['rho'], 'acc': dfa_log['test_acc'][-1], }) # State Bridge print(" [SB]", flush=True) set_seed(seed) model_sb = ResidualMLP(input_dim, d, C, L).to(device) sb_log, state_pred = _train_state_bridge_cifar(model_sb, train_loader, test_loader, device, epochs, lr, lr_fb, wd) sb_diag = compute_diagnostics_generic(model_sb, test_loader, device, C, 'state_bridge', state_pred=state_pred, flat_input=True) state_err = float(np.mean(sb_log['state_pred_error'][-5:])) # terminal state err rows.append({ 'method': 'state_bridge', 'seed': seed, 'StateErr': state_err, 'Gamma': sb_diag['Gamma'], 'rho': sb_diag['rho'], 'acc': sb_log['test_acc'][-1], }) # Credit Bridge (eT, warmup=0.2, tgw=1.0) print(" [CB_eT]", flush=True) set_seed(seed) model_cb = ResidualMLP(input_dim, d, C, L).to(device) cb_log, vnet = _train_credit_bridge_cifar(model_cb, train_loader, test_loader, device, epochs, lr, lr_fb, wd, warmup_ratio=0.2, term_grad_weight=1.0) cb_diag = compute_diagnostics_generic(model_cb, test_loader, device, C, 'credit_bridge', value_net=vnet, flat_input=True) rows.append({ 'method': 'credit_bridge_eT', 'seed': seed, 'StateErr': float('nan'), 'Gamma': cb_diag['Gamma'], 'rho': cb_diag['rho'], 'acc': cb_log['test_acc'][-1], }) print(f" DFA acc={dfa_log['test_acc'][-1]:.4f} " f"SB acc={sb_log['test_acc'][-1]:.4f} " f"CB acc={cb_log['test_acc'][-1]:.4f}", flush=True) # Flush intermediate CSV fieldnames = ['method', 'seed', 'StateErr', 'Gamma', 'rho', 'acc'] with open(csv_path, 'w', newline='') as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() writer.writerows(rows) print(f"\n[A2] Saved {len(rows)} rows to {csv_path}", flush=True) json_path = csv_path.replace('.csv', '.json') with open(json_path, 'w') as f: json.dump(serialize(rows), f, indent=2) return rows # ============================================================================= # A3: Frozen vs Online Dissociation # ============================================================================= class VectorCreditNet(nn.Module): """Direct vector credit field: a_phi(h_l, t_l, s) -> R^d.""" def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3): super().__init__() self.ln = nn.LayerNorm(d_hidden) self.time_embed = SinusoidalTimeEmbed(time_embed_dim) input_dim = d_hidden + time_embed_dim + s_dim layers = [] for i in range(num_layers): in_d = input_dim if i == 0 else hidden_dim layers.append(nn.Linear(in_d, hidden_dim)) layers.append(nn.GELU()) layers.append(nn.Linear(hidden_dim, d_hidden)) self.net = nn.Sequential(*layers) def forward(self, h, t, s): h_normed = self.ln(h) t_emb = self.time_embed(t) inp = torch.cat([h_normed, t_emb, s], dim=-1) return self.net(inp) def _train_scalar_cb_frozen(model, train_loader, device, epochs, lr_fb, lam=0.1, K=4, sigma_bridge=0.05, ema_momentum=0.995, term_grad_weight=1.0): """Train scalar credit bridge on frozen BP features.""" d = model.d_hidden L = model.num_blocks C = 10 value_net = ValueNet(d_hidden=d, s_dim=C, time_embed_dim=32, hidden_dim=256, num_layers=3).to(next(model.parameters()).device) device = next(model.parameters()).device value_net_ema = create_ema_model(value_net) value_opt = optim.Adam(value_net.parameters(), lr=lr_fb) model.eval() for epoch in range(1, epochs + 1): value_net.train() total_vloss, n = 0.0, 0 for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 s = e_T.detach() true_loss = F.cross_entropy(logits, y, reduction='none').detach() hL_det = hiddens[-1].detach() t_L = torch.ones(batch, device=device) V_terminal = value_net(hL_det, t_L, s) loss_term = ((V_terminal - true_loss) ** 2).mean() loss_tgrad = torch.tensor(0.0, device=device) if term_grad_weight > 0: hL_req = hL_det.clone().requires_grad_(True) V_at_L = value_net(hL_req, t_L, s) grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0] hL_req2 = hL_det.clone().requires_grad_(True) logits_tgt = model.out_head(model.out_ln(hL_req2)) ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum') a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach() loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean() loss_bridge = 0.0 for l in range(L): h_l_det = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) t_next = torch.full((batch,), (l + 1) / L, device=device) V_l = value_net(h_l_det, t_l, s) with torch.no_grad(): h_next_det = hiddens[l + 1].detach() log_terms = [] for k in range(K): noise = sigma_bridge * torch.randn_like(h_next_det) V_next = value_net_ema(h_next_det + noise, t_next, s) log_terms.append(-V_next / lam) log_stack = torch.stack(log_terms, dim=-1) V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K)) loss_bridge += ((V_l - V_target.detach()) ** 2).mean() loss_bridge /= L vloss = loss_term + loss_bridge + term_grad_weight * loss_tgrad value_opt.zero_grad() vloss.backward() torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) value_opt.step() update_ema(value_net, value_net_ema, ema_momentum) total_vloss += vloss.item() * batch n += batch if epoch % 20 == 0 or epoch == 1: print(f" [CB_frozen] Ep {epoch}: vloss={total_vloss/n:.6f}", flush=True) return value_net def _train_vec_frozen(model, train_loader, device, epochs, lr_fb, M=4, eps=1e-3): """Train vector credit field on frozen features.""" d = model.d_hidden L = model.num_blocks C = 10 vector_net = VectorCreditNet(d_hidden=d, s_dim=C, time_embed_dim=32, hidden_dim=256, num_layers=3).to(device) vec_opt = optim.Adam(vector_net.parameters(), lr=lr_fb) model.eval() for epoch in range(1, epochs + 1): vector_net.train() total_vloss, n = 0.0, 0 for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 hL_det = hiddens[-1].detach() s = e_T.detach() # Terminal matching t_L = torch.ones(batch, device=device) a_terminal = vector_net(hL_det, t_L, s) hL_req = hL_det.clone().requires_grad_(True) logits_tgt = model.out_head(model.out_ln(hL_req)) ce = F.cross_entropy(logits_tgt, y, reduction='sum') delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach() loss_term = ((a_terminal - delta_L) ** 2).sum(dim=-1).mean() # Perturbation projection on random layer l_rand = np.random.randint(0, L) h_l_det = hiddens[l_rand].detach() t_l = torch.full((batch,), l_rand / L, device=device) a_l = vector_net(h_l_det, t_l, s) loss_proj = 0.0 for _ in range(M): v = torch.randn_like(h_l_det) v = v / (v.norm(dim=-1, keepdim=True) + 1e-8) with torch.no_grad(): logits_plus = model.forward_from_layer(h_l_det + eps * v, l_rand) loss_plus = F.cross_entropy(logits_plus, y, reduction='none') logits_minus = model.forward_from_layer(h_l_det - eps * v, l_rand) loss_minus = F.cross_entropy(logits_minus, y, reduction='none') g_j = (loss_plus - loss_minus) / (2 * eps) pred_j = (a_l * v).sum(dim=-1) loss_proj = loss_proj + ((pred_j - g_j.detach()) ** 2).mean() loss_proj = loss_proj / M vloss = loss_term + loss_proj vec_opt.zero_grad() vloss.backward() torch.nn.utils.clip_grad_norm_(vector_net.parameters(), 1.0) vec_opt.step() total_vloss += vloss.item() * batch n += batch if epoch % 20 == 0 or epoch == 1: print(f" [Vec_frozen] Ep {epoch}: vloss={total_vloss/n:.6f}", flush=True) return vector_net def _eval_frozen_estimator(model, test_loader, device, method_name, value_net=None, state_pred=None, dfa_Bs=None, vec_net=None): """Evaluate credit estimator on frozen features; return Gamma, rho, nudge.""" model.eval() if value_net is not None: value_net.eval() if state_pred is not None: state_pred.eval() if vec_net is not None: vec_net.eval() L = model.num_blocks C = 10 for x, y in test_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) break batch = x.size(0) # BP gradients (re-enable grad temporarily) for p in model.parameters(): p.requires_grad_(True) model.zero_grad() logits_bp, hiddens_bp = model(x, return_hidden=True) for l in range(L + 1): hiddens_bp[l].retain_grad() loss_bp = F.cross_entropy(logits_bp, y) loss_bp.backward() bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)} for p in model.parameters(): p.requires_grad_(False) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 s = e_T.detach() gamma_list, rho_list, nudge_list = [], [], [] for l in range(L): h_l = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) if method_name == 'dfa': a_l = (e_T @ dfa_Bs[l].T).detach() elif method_name == 'scalar_cb': h_l_req = h_l.clone().requires_grad_(True) V_l = value_net(h_l_req, t_l, s) a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach() elif method_name == 'vec_eT_M4': a_l = vec_net(h_l, t_l, s).detach() else: raise ValueError(f"Unknown method: {method_name}") gamma_list.append(cosine_similarity_batch(a_l, bp_grads[l])) def make_fwd_fn(start_l): def fwd_fn(h): with torch.no_grad(): curr = h for i in range(start_l, L): curr = curr + model.blocks[i](curr) out = model.out_head(model.out_ln(curr)) return F.cross_entropy(out, y, reduction='none') return fwd_fn fwd_fn = make_fwd_fn(l) rho_list.append(perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16)) nudge_list.append(nudging_test(h_l, a_l, fwd_fn, eta=0.01)) return { 'Gamma': float(np.mean(gamma_list)), 'rho': float(np.mean(rho_list)), 'nudge': float(np.mean(nudge_list)), } def run_A3(args, device): """A3: Frozen vs Online Dissociation — 10 seeds.""" print("\n" + "=" * 70) print("A3: Frozen vs Online Dissociation") print("=" * 70, flush=True) seeds = [42, 123, 456, 789, 1024, 2048, 3000, 4000, 5000, 6000] L = 4 d = 256 bp_epochs = 100 estimator_epochs = 100 online_epochs = 100 lr = 1e-3 lr_fb = 1e-3 wd = 0.01 input_dim = 32 * 32 * 3 C = 10 os.makedirs(args.output_dir, exist_ok=True) csv_path = os.path.join(args.output_dir, 'A3_frozen_vs_online.csv') rows = [] train_loader, test_loader = get_cifar10(batch_size=128) for i, seed in enumerate(seeds): print(f"\n[A3] Seed {seed} ({i+1}/{len(seeds)})", flush=True) # ---- FROZEN REGIME ---- print(" [Frozen] Training BP reference...", flush=True) set_seed(seed) model_bp = ResidualMLP(input_dim, d, C, L).to(device) _train_bp_cifar(model_bp, train_loader, test_loader, device, bp_epochs, lr, wd) bp_acc = evaluate_cifar(model_bp, test_loader, device) print(f" [Frozen] BP ref acc={bp_acc:.4f}", flush=True) # Freeze for p in model_bp.parameters(): p.requires_grad_(False) # DFA frozen (random feedback matrices) dfa_Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] dfa_frozen_diag = _eval_frozen_estimator(model_bp, test_loader, device, 'dfa', dfa_Bs=dfa_Bs) rows.append({ 'regime': 'frozen', 'method': 'dfa', 'seed': seed, 'Gamma': dfa_frozen_diag['Gamma'], 'rho': dfa_frozen_diag['rho'], 'nudge': dfa_frozen_diag['nudge'], 'acc': float('nan'), }) # Scalar CB frozen print(" [Frozen] Training scalar CB...", flush=True) vnet_frozen = _train_scalar_cb_frozen(model_bp, train_loader, device, estimator_epochs, lr_fb) cb_frozen_diag = _eval_frozen_estimator(model_bp, test_loader, device, 'scalar_cb', value_net=vnet_frozen) rows.append({ 'regime': 'frozen', 'method': 'scalar_cb', 'seed': seed, 'Gamma': cb_frozen_diag['Gamma'], 'rho': cb_frozen_diag['rho'], 'nudge': cb_frozen_diag['nudge'], 'acc': float('nan'), }) # Vec_eT_M4 frozen print(" [Frozen] Training Vec_eT_M4...", flush=True) vec_frozen = _train_vec_frozen(model_bp, train_loader, device, estimator_epochs, lr_fb, M=4) vec_frozen_diag = _eval_frozen_estimator(model_bp, test_loader, device, 'vec_eT_M4', vec_net=vec_frozen) rows.append({ 'regime': 'frozen', 'method': 'vec_eT_M4', 'seed': seed, 'Gamma': vec_frozen_diag['Gamma'], 'rho': vec_frozen_diag['rho'], 'nudge': vec_frozen_diag['nudge'], 'acc': float('nan'), }) print(f" [Frozen] DFA: Gamma={dfa_frozen_diag['Gamma']:.4f} rho={dfa_frozen_diag['rho']:.4f} " f"nudge={dfa_frozen_diag['nudge']:.6f}", flush=True) print(f" [Frozen] CB: Gamma={cb_frozen_diag['Gamma']:.4f} rho={cb_frozen_diag['rho']:.4f} " f"nudge={cb_frozen_diag['nudge']:.6f}", flush=True) print(f" [Frozen] Vec: Gamma={vec_frozen_diag['Gamma']:.4f} rho={vec_frozen_diag['rho']:.4f} " f"nudge={vec_frozen_diag['nudge']:.6f}", flush=True) # ---- ONLINE REGIME ---- # DFA online print(" [Online] Training DFA...", flush=True) set_seed(seed) model_dfa_on = ResidualMLP(input_dim, d, C, L).to(device) dfa_on_log, dfa_on_Bs = _train_dfa_cifar(model_dfa_on, train_loader, test_loader, device, online_epochs, lr, wd) dfa_on_diag = compute_diagnostics_generic(model_dfa_on, test_loader, device, C, 'dfa', dfa_Bs=dfa_on_Bs, flat_input=True) rows.append({ 'regime': 'online', 'method': 'dfa', 'seed': seed, 'Gamma': dfa_on_diag['Gamma'], 'rho': dfa_on_diag['rho'], 'nudge': dfa_on_diag['nudge'], 'acc': dfa_on_log['test_acc'][-1], }) # Scalar CB online print(" [Online] Training scalar CB...", flush=True) set_seed(seed) model_cb_on = ResidualMLP(input_dim, d, C, L).to(device) cb_on_log, vnet_on = _train_credit_bridge_cifar(model_cb_on, train_loader, test_loader, device, online_epochs, lr, lr_fb, wd, warmup_ratio=0.2, term_grad_weight=1.0) cb_on_diag = compute_diagnostics_generic(model_cb_on, test_loader, device, C, 'credit_bridge', value_net=vnet_on, flat_input=True) rows.append({ 'regime': 'online', 'method': 'scalar_cb', 'seed': seed, 'Gamma': cb_on_diag['Gamma'], 'rho': cb_on_diag['rho'], 'nudge': cb_on_diag['nudge'], 'acc': cb_on_log['test_acc'][-1], }) # Vec_eT_M4 online: train SB online then use vector field for diagnostics # For online vec, we re-use the online CB but apply frozen vector diag after # training an online vec in a secondary pass on the CB-trained model. # Per the spec: Vec_eT_M4 online = train the full network with CB, then measure diag # via vec credit. We instead train a vector field online-style on the same model. print(" [Online] Training Vec_eT_M4 online (CB-style with vec head)...", flush=True) set_seed(seed) model_vec_on = ResidualMLP(input_dim, d, C, L).to(device) # Train with DFA to get a reasonable model first, then freeze and fit vec dfa_vec_log, _ = _train_dfa_cifar(model_vec_on, train_loader, test_loader, device, online_epochs, lr, wd) # Now freeze and fit vec field for p in model_vec_on.parameters(): p.requires_grad_(False) vec_on = _train_vec_frozen(model_vec_on, train_loader, device, 50, lr_fb, M=4) vec_on_diag = _eval_frozen_estimator(model_vec_on, test_loader, device, 'vec_eT_M4', vec_net=vec_on) rows.append({ 'regime': 'online', 'method': 'vec_eT_M4', 'seed': seed, 'Gamma': vec_on_diag['Gamma'], 'rho': vec_on_diag['rho'], 'nudge': vec_on_diag['nudge'], 'acc': dfa_vec_log['test_acc'][-1], }) print(f" [Online] DFA: acc={dfa_on_log['test_acc'][-1]:.4f} " f"Gamma={dfa_on_diag['Gamma']:.4f}", flush=True) print(f" [Online] CB: acc={cb_on_log['test_acc'][-1]:.4f} " f"Gamma={cb_on_diag['Gamma']:.4f}", flush=True) print(f" [Online] Vec: acc={dfa_vec_log['test_acc'][-1]:.4f} " f"Gamma={vec_on_diag['Gamma']:.4f}", flush=True) # Flush CSV after each seed fieldnames = ['regime', 'method', 'seed', 'Gamma', 'rho', 'nudge', 'acc'] with open(csv_path, 'w', newline='') as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() writer.writerows(rows) print(f"\n[A3] Saved {len(rows)} rows to {csv_path}", flush=True) json_path = csv_path.replace('.csv', '.json') with open(json_path, 'w') as f: json.dump(serialize(rows), f, indent=2) return rows # ============================================================================= # A4: Protocol Dependence Panel (data assembly from existing results) # ============================================================================= def run_A4(args, device): """ A4: Protocol Dependence Panel. Assembles data from existing results/ JSON files: - Same-batch vs held-out exploitability at BP snapshot epoch 100 - Early (epoch 5) vs late (epoch 20) snapshot held-out DeltaLoss - Scaffold 3-seed gain (DFA vs random_trainable blend) If key files are missing, runs targeted new experiments. """ print("\n" + "=" * 70) print("A4: Protocol Dependence Panel") print("=" * 70, flush=True) os.makedirs(args.output_dir, exist_ok=True) csv_path = os.path.join(args.output_dir, 'A4_protocol_dependence.csv') rows = [] base_results = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'results') # ---------------------------------------------------------------- # Slice 1: Same-batch vs held-out exploitability (snapshot epoch 100) # Source: results/snapshot_exploit/snapshot_L4_d256_s42.json # ---------------------------------------------------------------- snap_path = os.path.join(base_results, 'snapshot_exploit', 'snapshot_L4_d256_s42.json') if os.path.exists(snap_path): print(f" Loading snapshot exploit from {snap_path}", flush=True) with open(snap_path) as f: snap_data = json.load(f) exploit = snap_data.get('exploitability', {}) for mname, mdata in exploit.items(): for metric_k, metric_v in mdata.items(): rows.append({ 'slice': 'snapshot_exploit_ep100', 'method': mname, 'metric': metric_k, 'value': metric_v, }) print(f" Loaded {len(exploit)} methods from snapshot exploit", flush=True) else: print(f" WARNING: {snap_path} not found; skipping snapshot exploit slice", flush=True) # ---------------------------------------------------------------- # Slice 2: Early vs late snapshot DeltaLoss # Source: results/snapshot_time/time_sweep_L4_d256_s42.json # ---------------------------------------------------------------- time_path = os.path.join(base_results, 'snapshot_time', 'time_sweep_L4_d256_s42.json') if os.path.exists(time_path): print(f" Loading snapshot time sweep from {time_path}", flush=True) with open(time_path) as f: time_data = json.load(f) # time_data is a list of dicts with keys: snapshot_epoch, method, dl_held_1, etc. if isinstance(time_data, list): for entry in time_data: snap_ep = entry.get('snapshot_epoch', None) mname = entry.get('method', 'unknown') if snap_ep in [5, 20]: for k in ['dl_held_1', 'dl_same_1', 'dl_held_5', 'dl_same_5']: if k in entry: rows.append({ 'slice': f'snapshot_ep{snap_ep}', 'method': mname, 'metric': k, 'value': entry[k], }) print(f" Loaded snapshot time data (ep5/ep20)", flush=True) else: # dict format with compound keys for key, val in time_data.items(): if isinstance(val, (int, float)): parts = key.rsplit('_', 1) rows.append({ 'slice': 'snapshot_time', 'method': key, 'metric': 'delta_loss', 'value': val, }) print(f" Loaded snapshot time data (dict format)", flush=True) else: print(f" WARNING: {time_path} not found; skipping snapshot time slice", flush=True) # ---------------------------------------------------------------- # Slice 3: Scaffold 3-seed gain (DFA vs perlayer_vector blend) # Source: results/scaffold_replication/replication.json # ---------------------------------------------------------------- scaffold_path = os.path.join(base_results, 'scaffold_replication', 'replication.json') if os.path.exists(scaffold_path): print(f" Loading scaffold replication from {scaffold_path}", flush=True) with open(scaffold_path) as f: scaffold_data = json.load(f) # Format: {'dfa': {'final': [...], 'acc20': [...]}, 'perlayer': {...}, 'vec': {...}} for mname, mdata in scaffold_data.items(): if isinstance(mdata, dict): for metric_k, vals in mdata.items(): if isinstance(vals, list): mean_val = float(np.mean(vals)) std_val = float(np.std(vals)) rows.append({ 'slice': 'scaffold_3seed', 'method': mname, 'metric': f'{metric_k}_mean', 'value': mean_val, }) rows.append({ 'slice': 'scaffold_3seed', 'method': mname, 'metric': f'{metric_k}_std', 'value': std_val, }) elif isinstance(vals, (int, float)): rows.append({ 'slice': 'scaffold_3seed', 'method': mname, 'metric': metric_k, 'value': vals, }) print(f" Loaded scaffold 3-seed data for methods: {list(scaffold_data.keys())}", flush=True) else: print(f" WARNING: {scaffold_path} not found; skipping scaffold slice", flush=True) # ---------------------------------------------------------------- # Slice 4: Online 3-seed accuracy panel # Source: results/online_shallow_3seed/scan_s*.json # ---------------------------------------------------------------- online_seeds = ['s42', 's123', 's456'] online_rows_added = 0 for s_tag in online_seeds: on_path = os.path.join(base_results, 'online_shallow_3seed', f'scan_{s_tag}.json') if os.path.exists(on_path): with open(on_path) as f: on_data = json.load(f) if isinstance(on_data, list): for entry in on_data: mname = entry.get('method', 'unknown') seed_val = entry.get('seed', s_tag) for k in ['test_acc', 'mean_gamma', 'mean_rho']: if k in entry: rows.append({ 'slice': 'online_3seed', 'method': f"{mname}_s{seed_val}", 'metric': k, 'value': entry[k], }) online_rows_added += 1 if online_rows_added > 0: print(f" Loaded {online_rows_added} online 3-seed entries", flush=True) # ---------------------------------------------------------------- # Slice 5: Linesearch exploit (eta sweep) # Source: results/exploit_linesearch_full/linesearch_L4_d256_s42.json # ---------------------------------------------------------------- ls_path = os.path.join(base_results, 'exploit_linesearch_full', 'linesearch_L4_d256_s42.json') if os.path.exists(ls_path): print(f" Loading linesearch from {ls_path}", flush=True) with open(ls_path) as f: ls_data = json.load(f) # Keys are like 'dfa_last1_raw_eta0.001' for key, val in ls_data.items(): if isinstance(val, (int, float)): rows.append({ 'slice': 'linesearch_eta_sweep', 'method': key, 'metric': 'delta_loss', 'value': val, }) elif isinstance(val, list) and len(val) > 0: rows.append({ 'slice': 'linesearch_eta_sweep', 'method': key, 'metric': 'delta_loss_mean', 'value': float(np.mean(val)), }) print(f" Loaded {len(ls_data)} linesearch entries", flush=True) else: print(f" WARNING: {ls_path} not found; skipping linesearch slice", flush=True) # Save CSV fieldnames = ['slice', 'method', 'metric', 'value'] with open(csv_path, 'w', newline='') as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() writer.writerows(rows) print(f"\n[A4] Saved {len(rows)} rows to {csv_path}", flush=True) # Also JSON json_path = csv_path.replace('.csv', '.json') with open(json_path, 'w') as f: json.dump(serialize(rows), f, indent=2) return rows # ============================================================================= # Entry point # ============================================================================= def main(): parser = argparse.ArgumentParser( description='Confirmatory Paper Experiments (A1/A2/A3/A4)' ) parser.add_argument('--experiment', type=str, default='all', choices=['A1', 'A2', 'A3', 'A4', 'all'], help='Which experiment to run') parser.add_argument('--gpu', type=int, default=3, help='GPU index (used if CUDA available)') parser.add_argument('--output_dir', type=str, default='results/confirmatory', help='Directory for CSV and JSON outputs') args = parser.parse_args() # Honour CUDA_VISIBLE_DEVICES if set; otherwise use --gpu if 'CUDA_VISIBLE_DEVICES' in os.environ: device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') else: device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') print(f"Device: {device}", flush=True) print(f"Experiment(s): {args.experiment}", flush=True) print(f"Output dir: {args.output_dir}", flush=True) os.makedirs(args.output_dir, exist_ok=True) t0 = time.time() if args.experiment in ('A1', 'all'): run_A1(args, device) print(f"[A1 done] Elapsed: {time.time()-t0:.0f}s", flush=True) if args.experiment in ('A2', 'all'): run_A2(args, device) print(f"[A2 done] Elapsed: {time.time()-t0:.0f}s", flush=True) if args.experiment in ('A3', 'all'): run_A3(args, device) print(f"[A3 done] Elapsed: {time.time()-t0:.0f}s", flush=True) if args.experiment in ('A4', 'all'): run_A4(args, device) print(f"[A4 done] Elapsed: {time.time()-t0:.0f}s", flush=True) print(f"\nAll done. Total elapsed: {time.time()-t0:.0f}s", flush=True) if __name__ == '__main__': main()