"""
Phase A: Frozen CIFAR Credit Recovery.

Goal: Separate "estimator problem" from "forward exploitability problem".

1. Train a BP reference network to convergence, freeze it.
2. On frozen features, train credit estimators (state bridge, scalar CB
   with eT/deltaL).
3. Evaluate Gamma, rho, nudging per layer.

This answers: can the credit estimator recover useful local credit from
fixed representations?
"""
import os
import sys
import json
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import copy
import time

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.residual_mlp import ResidualMLP
from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
from models.state_bridge import StateBridgeNet
from metrics.credit_metrics import (
    cosine_similarity_batch, perturbation_correlation, nudging_test, offline_bp_cosine
)

CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)
NUM_CLASSES = 10
# Step sizes used by the nudging test; results are stored under f'nudging_{eta}'.
NUDGE_ETAS = (0.001, 0.003, 0.01)


# =============================================================================
# Shared helpers
# =============================================================================

def _flatten_batch(x, y, device):
    """Flatten image tensors to vectors and move the batch to ``device``."""
    return x.view(x.size(0), -1).to(device), y.to(device)


def _softmax_error(logits, y):
    """Return e_T = softmax(logits) - onehot(y) (per-sample dCE/dlogits).

    The index ``arange`` is built on the logits' device so the advanced
    indexing never mixes devices.
    """
    e_T = logits.softmax(dim=-1)
    e_T[torch.arange(logits.size(0), device=logits.device), y] -= 1
    return e_T


def _output_layer_grad(model, h_L, y):
    """Return grad_{h_L} of summed CE through the model's output head only.

    This is output-layer-local (uses ``model.out_ln`` and ``model.out_head``
    but no earlier blocks). Grad mode is forced on so the helper also works
    when called from an outer ``no_grad`` region.
    """
    with torch.enable_grad():
        h_req = h_L.detach().clone().requires_grad_(True)
        logits = model.out_head(model.out_ln(h_req))
        ce = F.cross_entropy(logits, y, reduction='sum')
        return torch.autograd.grad(ce, h_req, create_graph=False)[0].detach()


def _select_device(gpu):
    """Pick ``cuda:{gpu}`` if it exists, else fall back to ``cuda:0`` / CPU."""
    if not torch.cuda.is_available():
        return torch.device('cpu')
    n_gpus = torch.cuda.device_count()
    if not 0 <= gpu < n_gpus:
        print(f"Requested GPU {gpu} but only {n_gpus} visible; falling back to cuda:0")
        gpu = 0
    return torch.device(f'cuda:{gpu}')


def get_cifar10(batch_size=128, data_root='./data', num_workers=4):
    """Build CIFAR-10 train/test loaders (train uses crop+flip augmentation).

    Args:
        batch_size: mini-batch size for both loaders.
        data_root: directory where the dataset is stored / downloaded.
        num_workers: DataLoader worker processes.

    Returns:
        (train_loader, test_loader)
    """
    normalize = transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True,
                                            transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True,
                                           transform=transform_test)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader


def evaluate(model, test_loader, device):
    """Return top-1 accuracy of ``model`` on ``test_loader`` (eval mode, no grad)."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = _flatten_batch(x, y, device)
            logits = model(x)
            correct += (logits.argmax(1) == y).sum().item()
            total += x.size(0)
    return correct / total


# =============================================================================
# Step 1: Train BP reference network
# =============================================================================

def train_bp_reference(model, train_loader, test_loader, device, epochs=100, lr=1e-3, wd=0.01):
    """Train the BP reference end to end with AdamW and a cosine LR schedule.

    Logs running train loss/accuracy and test accuracy at epoch 1 and every
    10 epochs. Returns the final test accuracy.
    """
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, correct, total = 0.0, 0, 0
        for x, y in train_loader:
            x, y = _flatten_batch(x, y, device)
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * x.size(0)
            correct += (logits.argmax(1) == y).sum().item()
            total += x.size(0)
        scheduler.step()

        if epoch % 10 == 0 or epoch == 1:
            test_acc = evaluate(model, test_loader, device)
            print(f"  [BP ref] Epoch {epoch}: loss={total_loss/total:.4f}, "
                  f"train_acc={correct/total:.4f}, test_acc={test_acc:.4f}")

    test_acc = evaluate(model, test_loader, device)
    print(f"  [BP ref] Final test accuracy: {test_acc:.4f}")
    return test_acc


# =============================================================================
# Step 2: Train estimators on frozen features
# =============================================================================

def train_state_bridge_frozen(model, train_loader, device, args):
    """Train a state bridge h_l -> h_L on frozen BP features.

    The predictor is conditioned on (h_l, t=l/L, s=e_T) and trained with a
    squared error normalised by ||h_L|| (floored at 1), averaged over layers.
    Uses ``args.lr_fb`` and ``args.estimator_epochs``.
    """
    d = model.d_hidden
    L = model.num_blocks

    state_pred = StateBridgeNet(
        d_hidden=d, s_dim=NUM_CLASSES, time_embed_dim=32, hidden_dim=256, num_layers=3
    ).to(device)
    state_opt = optim.Adam(state_pred.parameters(), lr=args.lr_fb)

    model.eval()
    for epoch in range(1, args.estimator_epochs + 1):
        state_pred.train()
        total_loss = 0.0
        n = 0
        for x, y in train_loader:
            x, y = _flatten_batch(x, y, device)
            batch = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                s = _softmax_error(logits, y)
            hL_det = hiddens[-1].detach()
            # Loop-invariant normaliser: keeps large-norm targets from dominating.
            target_norm = hL_det.norm(dim=-1, keepdim=True).clamp(min=1.0)

            state_loss = 0.0
            for l in range(L):
                t_l = torch.full((batch,), l / L, device=device)
                pred_hL = state_pred(hiddens[l].detach(), t_l, s)
                state_loss = state_loss + (((pred_hL - hL_det) / target_norm) ** 2).sum(dim=-1).mean()
            state_loss = state_loss / L

            state_opt.zero_grad()
            state_loss.backward()
            state_opt.step()

            total_loss += state_loss.item() * batch
            n += batch

        if epoch % 20 == 0 or epoch == 1:
            print(f"  [SB] Epoch {epoch}: state_loss={total_loss/n:.6f}")

    return state_pred


def train_scalar_cb_frozen(model, train_loader, device, args, s_type='eT'):
    """
    Train scalar credit bridge on frozen BP features.

    s_type: 'eT' (softmax error, dim=10) or 'deltaL' (grad_{h_L} CE, dim=d_hidden)

    Losses:
        - terminal: V(h_L, 1, s) regressed onto the per-sample CE.
        - terminal gradient (weight ``args.term_grad_weight``): grad_h V at
          t=1 matched to the exact output-layer gradient.
        - bridge: V(h_l, l/L, s) regressed onto a soft-min
          -lam * log(mean_k exp(-V_ema(h_{l+1} + noise_k) / lam)).

    Returns:
        (value_net, value_net_ema)

    Raises:
        ValueError: for an unknown ``s_type``.
    """
    d = model.d_hidden
    L = model.num_blocks

    if s_type == 'eT':
        s_dim = NUM_CLASSES
    elif s_type == 'deltaL':
        s_dim = d
    else:
        raise ValueError(f"Unknown s_type: {s_type}")

    value_net = ValueNet(
        d_hidden=d, s_dim=s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3
    ).to(device)
    value_net_ema = create_ema_model(value_net)
    value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb)

    lam = args.lam
    K_samples = args.K
    sigma_bridge = args.sigma_bridge
    ema_momentum = args.ema_momentum
    term_grad_weight = args.term_grad_weight
    # The exact output-layer gradient doubles as the 'deltaL' code and the
    # terminal-gradient target, so compute it at most once per batch.
    need_exact_grad = s_type == 'deltaL' or term_grad_weight > 0
    log_K = np.log(K_samples)

    model.eval()
    for epoch in range(1, args.estimator_epochs + 1):
        value_net.train()
        total_vloss = 0.0
        total_term = 0.0
        total_tgrad = 0.0
        total_bridge = 0.0
        n = 0

        for x, y in train_loader:
            x, y = _flatten_batch(x, y, device)
            batch = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                e_T = _softmax_error(logits, y)
                true_loss = F.cross_entropy(logits, y, reduction='none')
            hL_det = hiddens[-1].detach()

            # delta_L = grad_{h_L} CE (output-layer-local, allowed)
            a_L_exact = _output_layer_grad(model, hL_det, y) if need_exact_grad else None
            s = e_T if s_type == 'eT' else a_L_exact

            # Terminal boundary
            t_L = torch.ones(batch, device=device)
            V_terminal = value_net(hL_det, t_L, s)
            loss_term = ((V_terminal - true_loss) ** 2).mean()

            # Terminal gradient matching
            loss_tgrad = torch.tensor(0.0, device=device)
            if term_grad_weight > 0:
                hL_req = hL_det.clone().requires_grad_(True)
                V_at_L = value_net(hL_req, t_L, s)
                grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
                loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()

            # Bridge consistency
            loss_bridge = 0.0
            for l in range(L):
                t_l = torch.full((batch,), l / L, device=device)
                t_l_next = torch.full((batch,), (l + 1) / L, device=device)
                V_l = value_net(hiddens[l].detach(), t_l, s)

                with torch.no_grad():
                    h_next_det = hiddens[l + 1].detach()
                    log_terms = [
                        -value_net_ema(h_next_det + sigma_bridge * torch.randn_like(h_next_det),
                                       t_l_next, s) / lam
                        for _ in range(K_samples)
                    ]
                    V_target = -lam * (torch.logsumexp(torch.stack(log_terms, dim=-1), dim=-1) - log_K)

                loss_bridge = loss_bridge + ((V_l - V_target) ** 2).mean()
            loss_bridge = loss_bridge / L

            value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad
            value_opt.zero_grad()
            value_loss.backward()
            torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
            value_opt.step()
            update_ema(value_net, value_net_ema, ema_momentum)

            total_vloss += value_loss.item() * batch
            total_term += loss_term.item() * batch
            total_tgrad += loss_tgrad.item() * batch
            total_bridge += (loss_bridge.item() if isinstance(loss_bridge, torch.Tensor)
                             else loss_bridge) * batch
            n += batch

        if epoch % 20 == 0 or epoch == 1:
            print(f"  [CB_{s_type}] Epoch {epoch}: vloss={total_vloss/n:.6f}, "
                  f"term={total_term/n:.6f}, tgrad={total_tgrad/n:.6f}, bridge={total_bridge/n:.6f}")

    return value_net, value_net_ema


# =============================================================================
# Step 3: Evaluate credit quality on frozen features
# =============================================================================

def _new_metric_store(L):
    """Create an empty per-layer metric container for one credit method."""
    store = {
        'bp_cosine': [[] for _ in range(L)],
        'perturbation_rho': [0.0] * L,
    }
    for eta in NUDGE_ETAS:
        store[f'nudging_{eta}'] = [0.0] * L
    return store


def _bp_hidden_grads(model, x, y, L):
    """Return {l: dCE/dh_l for l in 0..L} from true backprop through ``model``.

    Parameters are temporarily unfrozen so the hidden states join the autograd
    graph. They are always re-frozen, and their ``.grad`` buffers released,
    even if the forward/backward raises.
    """
    for p in model.parameters():
        p.requires_grad_(True)
    try:
        model.zero_grad()
        with torch.enable_grad():
            logits_bp, hiddens_bp = model(x, return_hidden=True)
            for l in range(L + 1):
                hiddens_bp[l].retain_grad()
            F.cross_entropy(logits_bp, y).backward()
        return {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)}
    finally:
        model.zero_grad(set_to_none=True)
        for p in model.parameters():
            p.requires_grad_(False)


def _make_suffix_loss_fn(model, start_l, y):
    """Build h -> per-sample CE from running blocks start_l..L-1 and the head on h."""
    L = model.num_blocks

    def fwd_fn(h):
        with torch.no_grad():
            curr = h
            for i in range(start_l, L):
                curr = curr + model.blocks[i](curr)
            out = model.out_head(model.out_ln(curr))
            return F.cross_entropy(out, y, reduction='none')

    return fwd_fn


def _estimator_credit(model, est, h_l, t_l, s_eT, delta_L, y):
    """Return the credit vector a_l = d(predicted loss)/dh_l for one estimator.

    'sb': differentiate CE of the head applied to the bridged state.
    'cb': differentiate the scalar value V(h_l, t_l, s).
    """
    if est['type'] == 'sb':
        net = est['net']
        net.eval()
        h_l_req = h_l.clone().requires_grad_(True)
        pred_hL = net(h_l_req, t_l, s_eT)
        pred_logits = model.out_head(model.out_ln(pred_hL))
        pred_loss = F.cross_entropy(pred_logits, y, reduction='sum')
        return torch.autograd.grad(pred_loss, h_l_req, create_graph=False)[0].detach()

    if est['type'] == 'cb':
        net = est['net']
        net.eval()
        s_type = est['s_type']
        if s_type == 'eT':
            s = s_eT
        elif s_type == 'deltaL':
            s = delta_L
        else:
            raise ValueError(f"Unknown s_type: {s_type}")
        h_l_req = h_l.clone().requires_grad_(True)
        V_l = net(h_l_req, t_l, s)
        return torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach()

    raise ValueError(f"Unknown estimator type: {est['type']}")


def _record_credit_metrics(store, l, a_l, bp_grad, h_l, fwd_fn, first_batch):
    """Append Gamma for layer l; on the first batch also record rho and nudging.

    rho / nudging values are coerced to Python floats so the results dict is
    always JSON-serialisable.
    """
    store['bp_cosine'][l].append(cosine_similarity_batch(a_l, bp_grad))
    if first_batch:
        # Only compute rho/nudging on first batch (expensive)
        store['perturbation_rho'][l] = float(
            perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=32))
        for eta in NUDGE_ETAS:
            store[f'nudging_{eta}'][l] = float(nudging_test(h_l, a_l, fwd_fn, eta=eta))


def evaluate_credits(model, test_loader, device, estimators, args):
    """
    Evaluate credit quality for all estimators on frozen BP features.

    Gamma (cosine to the true BP hidden gradient) is averaged over up to 10
    test batches. rho and nudging are measured on the first batch only. A
    DFA baseline with fixed random feedback matrices is added under 'dfa'.

    Args:
        estimators: dict of {name: {'type': 'sb'/'cb', 'net': ..., 's_type': ...}}
        args: unused; kept for interface compatibility.

    Returns:
        dict of {name: {per-layer metrics}}
    """
    model.eval()
    d = model.d_hidden
    L = model.num_blocks

    all_results = {name: _new_metric_store(L) for name in estimators}

    # Also add DFA baseline
    dfa_Bs = [torch.randn(d, NUM_CLASSES, device=device) / np.sqrt(NUM_CLASSES) for _ in range(L)]
    all_results['dfa'] = _new_metric_store(L)

    n_batches_diag = min(10, len(test_loader))
    for batch_idx, (x, y) in enumerate(test_loader, start=1):
        if batch_idx > n_batches_diag:
            break
        first_batch = batch_idx == 1
        x, y = _flatten_batch(x, y, device)
        batch = x.size(0)

        # Ground-truth hidden gradients for Gamma
        bp_grads = _bp_hidden_grads(model, x, y, L)

        # Clean forward
        with torch.no_grad():
            logits, hiddens = model(x, return_hidden=True)
            s_eT = _softmax_error(logits, y)
        delta_L = _output_layer_grad(model, hiddens[-1], y)

        for l in range(L):
            h_l = hiddens[l].detach()
            t_l = torch.full((batch,), l / L, device=device)
            fwd_fn = _make_suffix_loss_fn(model, l, y)

            # --- DFA credit ---
            a_dfa = (s_eT @ dfa_Bs[l].T).detach()
            _record_credit_metrics(all_results['dfa'], l, a_dfa, bp_grads[l], h_l, fwd_fn, first_batch)

            # --- Estimator credits ---
            for name, est in estimators.items():
                a_l = _estimator_credit(model, est, h_l, t_l, s_eT, delta_L, y)
                _record_credit_metrics(all_results[name], l, a_l, bp_grads[l], h_l, fwd_fn, first_batch)

    # Average bp_cosine over batches
    for store in all_results.values():
        store['bp_cosine'] = [float(np.mean(vals)) if vals else 0.0 for vals in store['bp_cosine']]

    return all_results


def evaluate_state_bridge_pred_error(model, state_pred, test_loader, device):
    """Evaluate state bridge's terminal state prediction error.

    Returns, per layer l, the test-set mean of ||pred_hL(h_l) - h_L||^2
    (unnormalised, unlike the training loss). Returns zeros for an empty loader.
    """
    model.eval()
    state_pred.eval()
    L = model.num_blocks
    total_error = [0.0] * L
    n = 0

    with torch.no_grad():
        for x, y in test_loader:
            x, y = _flatten_batch(x, y, device)
            batch = x.size(0)
            logits, hiddens = model(x, return_hidden=True)
            s = _softmax_error(logits, y)
            hL = hiddens[-1]
            for l in range(L):
                t_l = torch.full((batch,), l / L, device=x.device)
                pred_hL = state_pred(hiddens[l], t_l, s)
                error = ((pred_hL - hL) ** 2).sum(dim=-1).mean().item()
                total_error[l] += error * batch
            n += batch

    return [e / max(n, 1) for e in total_error]


# =============================================================================
# Main experiment
# =============================================================================

def _print_layer_table(title, header, results, methods, key, L, spec):
    """Print one per-layer table of ``results[m][key][l]`` formatted with ``spec``."""
    print(title)
    print(header)
    for l in range(L):
        row = f"  {l:<6}" + "".join(f" {results[m][key][l]:>{spec}}" for m in methods)
        print(row)


def run_experiment(args):
    """Run Phase A end to end: BP reference, estimator training, evaluation, save, verdict."""
    device = _select_device(args.gpu)
    print(f"Using device: {device}")
    os.makedirs(args.output_dir, exist_ok=True)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    train_loader, test_loader = get_cifar10(
        batch_size=args.batch_size,
        data_root=getattr(args, 'data_root', './data'),
        num_workers=getattr(args, 'num_workers', 4),
    )
    input_dim = 32 * 32 * 3
    num_classes = NUM_CLASSES

    # ----- Step 1: Train BP reference -----
    print(f"\n{'='*60}")
    print(f"Step 1: Train BP reference (L={args.num_blocks}, d={args.d_hidden})")
    print(f"{'='*60}")

    bp_ckpt_path = os.path.join(args.output_dir,
                                f'bp_ref_L{args.num_blocks}_d{args.d_hidden}_s{args.seed}.pt')
    model = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)

    if os.path.exists(bp_ckpt_path) and not args.retrain_bp:
        print(f"  Loading BP reference from {bp_ckpt_path}")
        model.load_state_dict(torch.load(bp_ckpt_path, map_location=device))
        bp_acc = evaluate(model, test_loader, device)
        print(f"  BP reference test accuracy: {bp_acc:.4f}")
    else:
        bp_acc = train_bp_reference(model, train_loader, test_loader, device,
                                    epochs=args.bp_epochs, lr=args.lr, wd=args.wd)
        torch.save(model.state_dict(), bp_ckpt_path)
        print(f"  Saved BP reference to {bp_ckpt_path}")

    # Freeze the model completely
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)

    # ----- Step 2: Train estimators -----
    print(f"\n{'='*60}")
    print(f"Step 2: Train estimators ({args.estimator_epochs} epochs each)")
    print(f"{'='*60}")

    estimators = {}

    # 2a. State Bridge with s=eT
    print("\n--- State Bridge (s=eT) ---")
    torch.manual_seed(args.seed + 1000)
    sb = train_state_bridge_frozen(model, train_loader, device, args)
    estimators['sb_eT'] = {'type': 'sb', 'net': sb, 's_type': 'eT'}

    # 2b. Scalar CB with s=eT
    print("\n--- Scalar CB (s=eT) ---")
    torch.manual_seed(args.seed + 2000)
    cb_eT, cb_eT_ema = train_scalar_cb_frozen(model, train_loader, device, args, s_type='eT')
    estimators['cb_eT'] = {'type': 'cb', 'net': cb_eT, 's_type': 'eT'}

    # 2c. Scalar CB with s=deltaL
    print("\n--- Scalar CB (s=deltaL) ---")
    torch.manual_seed(args.seed + 3000)
    cb_dL, cb_dL_ema = train_scalar_cb_frozen(model, train_loader, device, args, s_type='deltaL')
    estimators['cb_deltaL'] = {'type': 'cb', 'net': cb_dL, 's_type': 'deltaL'}

    # ----- Step 3: Evaluate -----
    print(f"\n{'='*60}")
    print(f"Step 3: Evaluate credit quality")
    print(f"{'='*60}")

    results = evaluate_credits(model, test_loader, device, estimators, args)
    sb_pred_error = evaluate_state_bridge_pred_error(model, sb, test_loader, device)

    # ----- Print results -----
    L = args.num_blocks
    print(f"\n{'='*60}")
    print(f"RESULTS: Frozen CIFAR Credit Recovery (L={L}, d={args.d_hidden}, seed={args.seed})")
    print(f"BP reference test accuracy: {bp_acc:.4f}")
    print(f"{'='*60}")

    methods = ['dfa', 'sb_eT', 'cb_eT', 'cb_deltaL']
    method_labels = {
        'dfa': 'DFA (random)',
        'sb_eT': 'State Bridge (eT)',
        'cb_eT': 'Scalar CB (eT)',
        'cb_deltaL': 'Scalar CB (deltaL)',
    }

    print(f"\n{'Method':<25} {'mean Gamma':>12} {'mean rho':>12} {'mean nudge':>12}")
    print("-" * 65)
    summary = {}
    for m in methods:
        r = results[m]
        mean_gamma = np.mean(r['bp_cosine'])
        mean_rho = np.mean(r['perturbation_rho'])
        mean_nudge = np.mean(r['nudging_0.003'])
        summary[m] = {
            'mean_gamma': float(mean_gamma),
            'mean_rho': float(mean_rho),
            'mean_nudge': float(mean_nudge),
        }
        print(f"{method_labels[m]:<25} {mean_gamma:>12.4f} {mean_rho:>12.4f} {mean_nudge:>12.6f}")

    header = f"{'Layer':<8}" + "".join(f" {method_labels[m]:>16}" for m in methods)
    _print_layer_table(f"\n--- Per-layer Gamma ---", header, results, methods,
                       'bp_cosine', L, '16.4f')
    _print_layer_table(f"\n--- Per-layer rho ---", header, results, methods,
                       'perturbation_rho', L, '16.4f')
    _print_layer_table(f"\n--- Per-layer nudge (eta=0.003) ---", header, results, methods,
                       'nudging_0.003', L, '16.6f')

    print(f"\n--- State Bridge prediction error per layer ---")
    for l in range(L):
        print(f"  Layer {l}: {sb_pred_error[l]:.6f}")

    # ----- Save all results -----
    save_data = {
        'config': {
            'num_blocks': args.num_blocks,
            'd_hidden': args.d_hidden,
            'seed': args.seed,
            'bp_epochs': args.bp_epochs,
            'estimator_epochs': args.estimator_epochs,
            'lr_fb': args.lr_fb,
            'lam': args.lam,
            'K': args.K,
            'sigma_bridge': args.sigma_bridge,
            'ema_momentum': args.ema_momentum,
            'term_grad_weight': args.term_grad_weight,
        },
        'bp_acc': float(bp_acc),
        'summary': summary,
        'per_layer': {},
        'sb_pred_error': sb_pred_error,
    }
    metric_keys = ['bp_cosine', 'perturbation_rho'] + [f'nudging_{eta}' for eta in NUDGE_ETAS]
    for m in methods:
        save_data['per_layer'][m] = {k: results[m][k] for k in metric_keys}

    out_path = os.path.join(args.output_dir,
                            f'frozen_L{args.num_blocks}_d{args.d_hidden}_s{args.seed}.json')
    with open(out_path, 'w') as f:
        json.dump(save_data, f, indent=2)
    print(f"\nResults saved to {out_path}")

    # ----- Judgment -----
    print(f"\n{'='*60}")
    print("JUDGMENT")
    print(f"{'='*60}")

    best_cb = max(summary['cb_eT']['mean_rho'], summary['cb_deltaL']['mean_rho'])
    dfa_rho = summary['dfa']['mean_rho']
    best_cb_gamma = max(summary['cb_eT']['mean_gamma'], summary['cb_deltaL']['mean_gamma'])
    dfa_gamma = summary['dfa']['mean_gamma']

    if best_cb > dfa_rho + 0.02 and best_cb_gamma > dfa_gamma:
        print("POSITIVE: Scalar CB recovers credit that is clearly better than DFA.")
        print("  -> Bottleneck is in forward exploitability / local update, not estimator.")
        print("  -> Next: Phase B (online shallow CIFAR).")
    elif best_cb > 0.02:
        print("MARGINAL: Scalar CB shows some signal but not clearly better than DFA.")
        print("  -> Need more investigation before concluding estimator is the bottleneck.")
    else:
        print("NEGATIVE: Scalar CB cannot recover useful credit even on frozen features.")
        print("  -> Estimator parameterization is the bottleneck.")
        print("  -> Next: Phase C (direct vector field pilot).")

    return save_data


def main():
    """Parse CLI arguments and run the experiment."""
    parser = argparse.ArgumentParser(description='Frozen CIFAR Credit Recovery')
    parser.add_argument('--num_blocks', type=int, default=4)
    parser.add_argument('--d_hidden', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--bp_epochs', type=int, default=100,
                        help='Epochs to train BP reference')
    parser.add_argument('--estimator_epochs', type=int, default=100,
                        help='Epochs to train each estimator on frozen features')
    parser.add_argument('--lr', type=float, default=1e-3, help='LR for BP reference')
    parser.add_argument('--lr_fb', type=float, default=1e-3, help='LR for estimators')
    parser.add_argument('--wd', type=float, default=0.01)
    parser.add_argument('--lam', type=float, default=0.1)
    parser.add_argument('--K', type=int, default=4)
    parser.add_argument('--sigma_bridge', type=float, default=0.05)
    parser.add_argument('--ema_momentum', type=float, default=0.995)
    parser.add_argument('--term_grad_weight', type=float, default=1.0)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--gpu', type=int, default=2)
    parser.add_argument('--output_dir', type=str, default='results/frozen_cifar')
    parser.add_argument('--data_root', type=str, default='./data')
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--retrain_bp', action='store_true')
    args = parser.parse_args()
    run_experiment(args)


if __name__ == '__main__':
    main()