summaryrefslogtreecommitdiff
path: root/experiments/cifar_frozen_credit_recovery.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/cifar_frozen_credit_recovery.py')
-rw-r--r--experiments/cifar_frozen_credit_recovery.py693
1 files changed, 693 insertions, 0 deletions
diff --git a/experiments/cifar_frozen_credit_recovery.py b/experiments/cifar_frozen_credit_recovery.py
new file mode 100644
index 0000000..5d39308
--- /dev/null
+++ b/experiments/cifar_frozen_credit_recovery.py
@@ -0,0 +1,693 @@
+"""
+Phase A: Frozen CIFAR Credit Recovery.
+
+Goal: Separate "estimator problem" from "forward exploitability problem".
+1. Train a BP reference network to convergence, freeze it.
+2. On frozen features, train credit estimators (state bridge, scalar CB with eT/deltaL).
+3. Evaluate Gamma, rho, nudging per layer.
+
+This answers: can the credit estimator recover useful local credit from fixed representations?
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import copy
+import time
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+from metrics.credit_metrics import (
+ cosine_similarity_batch, perturbation_correlation, nudging_test,
+ offline_bp_cosine
+)
+
+
+def get_cifar10(batch_size=128):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+ train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+ test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+ return train_loader, test_loader
+
+
+def evaluate(model, test_loader, device):
+ model.eval()
+ correct, total = 0, 0
+ with torch.no_grad():
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ logits = model(x)
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+ return correct / total
+
+
+# =============================================================================
+# Step 1: Train BP reference network
+# =============================================================================
+def train_bp_reference(model, train_loader, test_loader, device, epochs=100, lr=1e-3, wd=0.01):
+ """Train BP reference to convergence."""
+ optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+
+ for epoch in range(1, epochs + 1):
+ model.train()
+ total_loss, correct, total = 0, 0, 0
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ logits = model(x)
+ loss = F.cross_entropy(logits, y)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ total_loss += loss.item() * x.size(0)
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+ scheduler.step()
+ if epoch % 10 == 0 or epoch == 1:
+ test_acc = evaluate(model, test_loader, device)
+ print(f" [BP ref] Epoch {epoch}: loss={total_loss/total:.4f}, "
+ f"train_acc={correct/total:.4f}, test_acc={test_acc:.4f}")
+
+ test_acc = evaluate(model, test_loader, device)
+ print(f" [BP ref] Final test accuracy: {test_acc:.4f}")
+ return test_acc
+
+
+# =============================================================================
+# Step 2: Train estimators on frozen features
+# =============================================================================
+
+def train_state_bridge_frozen(model, train_loader, device, args):
+ """Train state bridge on frozen BP features."""
+ d = model.d_hidden
+ L = model.num_blocks
+ num_classes = 10
+
+ state_pred = StateBridgeNet(
+ d_hidden=d, s_dim=num_classes, time_embed_dim=32,
+ hidden_dim=256, num_layers=3
+ ).to(device)
+ state_opt = optim.Adam(state_pred.parameters(), lr=args.lr_fb)
+
+ model.eval()
+ for epoch in range(1, args.estimator_epochs + 1):
+ state_pred.train()
+ total_loss = 0
+ n = 0
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ hL_det = hiddens[-1].detach()
+
+ # Train state predictor
+ state_loss = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ pred_hL = state_pred(h_l_det, t_l, s)
+ target_norm = hL_det.norm(dim=-1, keepdim=True).clamp(min=1.0)
+ state_loss = state_loss + (((pred_hL - hL_det) / target_norm) ** 2).sum(dim=-1).mean()
+ state_loss = state_loss / L
+
+ state_opt.zero_grad()
+ state_loss.backward()
+ state_opt.step()
+ total_loss += state_loss.item() * batch
+ n += batch
+
+ if epoch % 20 == 0 or epoch == 1:
+ print(f" [SB] Epoch {epoch}: state_loss={total_loss/n:.6f}")
+
+ return state_pred
+
+
+def train_scalar_cb_frozen(model, train_loader, device, args, s_type='eT'):
+ """
+ Train scalar credit bridge on frozen BP features.
+ s_type: 'eT' (softmax error, dim=10) or 'deltaL' (grad_{h_L} CE, dim=d_hidden)
+ """
+ d = model.d_hidden
+ L = model.num_blocks
+ num_classes = 10
+
+ if s_type == 'eT':
+ s_dim = num_classes
+ elif s_type == 'deltaL':
+ s_dim = d
+ else:
+ raise ValueError(f"Unknown s_type: {s_type}")
+
+ value_net = ValueNet(
+ d_hidden=d, s_dim=s_dim, time_embed_dim=32,
+ hidden_dim=256, num_layers=3
+ ).to(device)
+ value_net_ema = create_ema_model(value_net)
+ value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb)
+
+ lam = args.lam
+ K_samples = args.K
+ sigma_bridge = args.sigma_bridge
+ ema_momentum = args.ema_momentum
+ term_grad_weight = args.term_grad_weight
+
+ model.eval()
+ for epoch in range(1, args.estimator_epochs + 1):
+ value_net.train()
+ total_vloss = 0
+ total_term = 0
+ total_tgrad = 0
+ total_bridge = 0
+ n = 0
+
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ true_loss = F.cross_entropy(logits, y, reduction='none').detach()
+
+ hL_det = hiddens[-1].detach()
+
+ # Compute s (conditioning code)
+ if s_type == 'eT':
+ s = e_T.detach()
+ elif s_type == 'deltaL':
+ # delta_L = grad_{h_L} CE (output-layer-local, allowed)
+ hL_req = hL_det.clone().requires_grad_(True)
+ logits_for_s = model.out_head(model.out_ln(hL_req))
+ ce_for_s = F.cross_entropy(logits_for_s, y, reduction='sum')
+ delta_L = torch.autograd.grad(ce_for_s, hL_req, create_graph=False)[0].detach()
+ s = delta_L
+
+ # Terminal boundary
+ t_L = torch.ones(batch, device=device)
+ V_terminal = value_net(hL_det, t_L, s)
+ loss_term = ((V_terminal - true_loss) ** 2).mean()
+
+ # Terminal gradient matching
+ loss_tgrad = torch.tensor(0.0, device=device)
+ if term_grad_weight > 0:
+ hL_req = hL_det.clone().requires_grad_(True)
+ V_at_L = value_net(hL_req, t_L, s)
+ grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
+ # Exact terminal gradient (output-layer-local)
+ hL_req2 = hL_det.clone().requires_grad_(True)
+ logits_tgt = model.out_head(model.out_ln(hL_req2))
+ ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum')
+ a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach()
+ loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()
+
+ # Bridge consistency
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ t_l_next = torch.full((batch,), (l + 1) / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+
+ with torch.no_grad():
+ h_next_det = hiddens[l + 1].detach()
+ log_terms = []
+ for k in range(K_samples):
+ noise = sigma_bridge * torch.randn_like(h_next_det)
+ V_next = value_net_ema(h_next_det + noise, t_l_next, s)
+ log_terms.append(-V_next / lam)
+ log_stack = torch.stack(log_terms, dim=-1)
+ V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K_samples))
+
+ loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
+ loss_bridge = loss_bridge / L
+
+ value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad
+
+ value_opt.zero_grad()
+ value_loss.backward()
+ torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
+ value_opt.step()
+ update_ema(value_net, value_net_ema, ema_momentum)
+
+ total_vloss += value_loss.item() * batch
+ total_term += loss_term.item() * batch
+ total_tgrad += loss_tgrad.item() * batch
+ total_bridge += (loss_bridge.item() if isinstance(loss_bridge, torch.Tensor) else loss_bridge) * batch
+ n += batch
+
+ if epoch % 20 == 0 or epoch == 1:
+ print(f" [CB_{s_type}] Epoch {epoch}: vloss={total_vloss/n:.6f}, "
+ f"term={total_term/n:.6f}, tgrad={total_tgrad/n:.6f}, bridge={total_bridge/n:.6f}")
+
+ return value_net, value_net_ema
+
+
+# =============================================================================
+# Step 3: Evaluate credit quality on frozen features
+# =============================================================================
+
+def evaluate_credits(model, test_loader, device, estimators, args):
+ """
+ Evaluate credit quality for all estimators on frozen BP features.
+
+ Args:
+ estimators: dict of {name: {'type': 'sb'/'cb', 'net': ..., 's_type': ...}}
+ Returns:
+ dict of {name: {per-layer metrics}}
+ """
+ model.eval()
+ d = model.d_hidden
+ L = model.num_blocks
+ num_classes = 10
+
+ # Accumulate over multiple test batches for robust statistics
+ all_results = {}
+ for name in estimators:
+ all_results[name] = {
+ 'bp_cosine': [[] for _ in range(L)],
+ 'perturbation_rho': [0.0] * L,
+ 'nudging_0.001': [0.0] * L,
+ 'nudging_0.003': [0.0] * L,
+ 'nudging_0.01': [0.0] * L,
+ }
+
+ # Also add DFA baseline
+ dfa_Bs = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)]
+ all_results['dfa'] = {
+ 'bp_cosine': [[] for _ in range(L)],
+ 'perturbation_rho': [0.0] * L,
+ 'nudging_0.001': [0.0] * L,
+ 'nudging_0.003': [0.0] * L,
+ 'nudging_0.01': [0.0] * L,
+ }
+
+ n_batches_diag = min(10, len(test_loader)) # Use multiple batches
+ batch_idx = 0
+
+ for x, y in test_loader:
+ if batch_idx >= n_batches_diag:
+ break
+ batch_idx += 1
+
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+
+ # Get BP gradients (ground truth for Gamma)
+ # Temporarily enable grad on model params for BP gradient computation
+ for p in model.parameters():
+ p.requires_grad_(True)
+ model.zero_grad()
+ logits_bp, hiddens_bp = model(x, return_hidden=True)
+ for l in range(L + 1):
+ hiddens_bp[l].retain_grad()
+ loss_bp = F.cross_entropy(logits_bp, y)
+ loss_bp.backward()
+ bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)}
+ # Re-freeze model
+ for p in model.parameters():
+ p.requires_grad_(False)
+
+ # Clean forward
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s_eT = e_T.detach()
+
+ hL_det = hiddens[-1].detach()
+
+ # Compute delta_L for deltaL conditioning
+ hL_req = hL_det.clone().requires_grad_(True)
+ logits_for_delta = model.out_head(model.out_ln(hL_req))
+ ce_for_delta = F.cross_entropy(logits_for_delta, y, reduction='sum')
+ delta_L = torch.autograd.grad(ce_for_delta, hL_req, create_graph=False)[0].detach()
+
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+
+ # Forward function for perturbation and nudging
+ def make_fwd_fn(start_l):
+ def fwd_fn(h):
+ with torch.no_grad():
+ curr = h
+ for i in range(start_l, L):
+ curr = curr + model.blocks[i](curr)
+ out = model.out_head(model.out_ln(curr))
+ return F.cross_entropy(out, y, reduction='none')
+ return fwd_fn
+
+ fwd_fn = make_fwd_fn(l)
+
+ # --- DFA credit ---
+ a_dfa = (s_eT @ dfa_Bs[l].T).detach()
+ bp_cos_dfa = cosine_similarity_batch(a_dfa, bp_grads[l])
+ all_results['dfa']['bp_cosine'][l].append(bp_cos_dfa)
+
+ if batch_idx == 1: # Only compute rho/nudging on first batch (expensive)
+ rho_dfa = perturbation_correlation(h_l, a_dfa, fwd_fn, epsilon=1e-3, M=32)
+ all_results['dfa']['perturbation_rho'][l] = rho_dfa
+ for eta in [0.001, 0.003, 0.01]:
+ nud = nudging_test(h_l, a_dfa, fwd_fn, eta=eta)
+ all_results['dfa'][f'nudging_{eta}'][l] = nud
+
+ # --- Estimator credits ---
+ for name, est in estimators.items():
+ if est['type'] == 'sb':
+ net = est['net']
+ net.eval()
+ h_l_req = h_l.clone().requires_grad_(True)
+ pred_hL = net(h_l_req, t_l, s_eT)
+ pred_logits = model.out_head(model.out_ln(pred_hL))
+ pred_loss = F.cross_entropy(pred_logits, y, reduction='sum')
+ a_l = torch.autograd.grad(pred_loss, h_l_req, create_graph=False)[0].detach()
+
+ elif est['type'] == 'cb':
+ net = est['net']
+ net.eval()
+ s_type = est['s_type']
+ if s_type == 'eT':
+ s = s_eT
+ elif s_type == 'deltaL':
+ s = delta_L
+ else:
+ raise ValueError(f"Unknown s_type: {s_type}")
+
+ h_l_req = h_l.clone().requires_grad_(True)
+ V_l = net(h_l_req, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach()
+ else:
+ raise ValueError(f"Unknown estimator type: {est['type']}")
+
+ bp_cos = cosine_similarity_batch(a_l, bp_grads[l])
+ all_results[name]['bp_cosine'][l].append(bp_cos)
+
+ if batch_idx == 1:
+ rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=32)
+ all_results[name]['perturbation_rho'][l] = rho
+ for eta in [0.001, 0.003, 0.01]:
+ nud = nudging_test(h_l, a_l, fwd_fn, eta=eta)
+ all_results[name][f'nudging_{eta}'][l] = nud
+
+ # Average bp_cosine over batches
+ for name in all_results:
+ for l in range(L):
+ vals = all_results[name]['bp_cosine'][l]
+ all_results[name]['bp_cosine'][l] = float(np.mean(vals)) if vals else 0.0
+
+ return all_results
+
+
+def evaluate_state_bridge_pred_error(model, state_pred, test_loader, device):
+ """Evaluate state bridge's terminal state prediction error."""
+ model.eval()
+ state_pred.eval()
+ L = model.num_blocks
+
+ total_error = [0.0] * L
+ n = 0
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ hL = hiddens[-1]
+
+ for l in range(L):
+ h_l = hiddens[l]
+ t_l = torch.full((batch,), l / L, device=x.device)
+ pred_hL = state_pred(h_l, t_l, s)
+ error = ((pred_hL - hL) ** 2).sum(dim=-1).mean().item()
+ total_error[l] += error * batch
+ n += batch
+
+ return [e / n for e in total_error]
+
+
+# =============================================================================
+# Main experiment
+# =============================================================================
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+
+ train_loader, test_loader = get_cifar10(batch_size=args.batch_size)
+ input_dim = 32 * 32 * 3
+ num_classes = 10
+
+ # ----- Step 1: Train BP reference -----
+ print(f"\n{'='*60}")
+ print(f"Step 1: Train BP reference (L={args.num_blocks}, d={args.d_hidden})")
+ print(f"{'='*60}")
+
+ bp_ckpt_path = os.path.join(args.output_dir, f'bp_ref_L{args.num_blocks}_d{args.d_hidden}_s{args.seed}.pt')
+
+ model = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
+
+ if os.path.exists(bp_ckpt_path) and not args.retrain_bp:
+ print(f" Loading BP reference from {bp_ckpt_path}")
+ model.load_state_dict(torch.load(bp_ckpt_path, map_location=device))
+ bp_acc = evaluate(model, test_loader, device)
+ print(f" BP reference test accuracy: {bp_acc:.4f}")
+ else:
+ bp_acc = train_bp_reference(model, train_loader, test_loader, device,
+ epochs=args.bp_epochs, lr=args.lr, wd=args.wd)
+ torch.save(model.state_dict(), bp_ckpt_path)
+ print(f" Saved BP reference to {bp_ckpt_path}")
+
+ # Freeze the model completely
+ model.eval()
+ for p in model.parameters():
+ p.requires_grad_(False)
+
+ # ----- Step 2: Train estimators -----
+ print(f"\n{'='*60}")
+ print(f"Step 2: Train estimators ({args.estimator_epochs} epochs each)")
+ print(f"{'='*60}")
+
+ estimators = {}
+
+ # 2a. State Bridge with s=eT
+ print("\n--- State Bridge (s=eT) ---")
+ torch.manual_seed(args.seed + 1000)
+ sb = train_state_bridge_frozen(model, train_loader, device, args)
+ estimators['sb_eT'] = {'type': 'sb', 'net': sb, 's_type': 'eT'}
+
+ # 2b. Scalar CB with s=eT
+ print("\n--- Scalar CB (s=eT) ---")
+ torch.manual_seed(args.seed + 2000)
+ cb_eT, cb_eT_ema = train_scalar_cb_frozen(model, train_loader, device, args, s_type='eT')
+ estimators['cb_eT'] = {'type': 'cb', 'net': cb_eT, 's_type': 'eT'}
+
+ # 2c. Scalar CB with s=deltaL
+ print("\n--- Scalar CB (s=deltaL) ---")
+ torch.manual_seed(args.seed + 3000)
+ cb_dL, cb_dL_ema = train_scalar_cb_frozen(model, train_loader, device, args, s_type='deltaL')
+ estimators['cb_deltaL'] = {'type': 'cb', 'net': cb_dL, 's_type': 'deltaL'}
+
+ # ----- Step 3: Evaluate -----
+ print(f"\n{'='*60}")
+ print(f"Step 3: Evaluate credit quality")
+ print(f"{'='*60}")
+
+ results = evaluate_credits(model, test_loader, device, estimators, args)
+
+ # State bridge prediction error
+ sb_pred_error = evaluate_state_bridge_pred_error(model, sb, test_loader, device)
+
+ # ----- Print results -----
+ L = args.num_blocks
+ print(f"\n{'='*60}")
+ print(f"RESULTS: Frozen CIFAR Credit Recovery (L={L}, d={args.d_hidden}, seed={args.seed})")
+ print(f"BP reference test accuracy: {bp_acc:.4f}")
+ print(f"{'='*60}")
+
+ # Summary table
+ methods = ['dfa', 'sb_eT', 'cb_eT', 'cb_deltaL']
+ method_labels = {
+ 'dfa': 'DFA (random)',
+ 'sb_eT': 'State Bridge (eT)',
+ 'cb_eT': 'Scalar CB (eT)',
+ 'cb_deltaL': 'Scalar CB (deltaL)',
+ }
+
+ print(f"\n{'Method':<25} {'mean Gamma':>12} {'mean rho':>12} {'mean nudge':>12}")
+ print("-" * 65)
+
+ summary = {}
+ for m in methods:
+ r = results[m]
+ mean_gamma = np.mean(r['bp_cosine'])
+ mean_rho = np.mean(r['perturbation_rho'])
+ mean_nudge = np.mean(r['nudging_0.003'])
+ summary[m] = {
+ 'mean_gamma': float(mean_gamma),
+ 'mean_rho': float(mean_rho),
+ 'mean_nudge': float(mean_nudge),
+ }
+ print(f"{method_labels[m]:<25} {mean_gamma:>12.4f} {mean_rho:>12.4f} {mean_nudge:>12.6f}")
+
+ # Per-layer detail
+ print(f"\n--- Per-layer Gamma ---")
+ header = f"{'Layer':<8}"
+ for m in methods:
+ header += f" {method_labels[m]:>16}"
+ print(header)
+ for l in range(L):
+ row = f" {l:<6}"
+ for m in methods:
+ row += f" {results[m]['bp_cosine'][l]:>16.4f}"
+ print(row)
+
+ print(f"\n--- Per-layer rho ---")
+ print(header)
+ for l in range(L):
+ row = f" {l:<6}"
+ for m in methods:
+ row += f" {results[m]['perturbation_rho'][l]:>16.4f}"
+ print(row)
+
+ print(f"\n--- Per-layer nudge (eta=0.003) ---")
+ print(header)
+ for l in range(L):
+ row = f" {l:<6}"
+ for m in methods:
+ row += f" {results[m]['nudging_0.003'][l]:>16.6f}"
+ print(row)
+
+ print(f"\n--- State Bridge prediction error per layer ---")
+ for l in range(L):
+ print(f" Layer {l}: {sb_pred_error[l]:.6f}")
+
+ # ----- Save all results -----
+ save_data = {
+ 'config': {
+ 'num_blocks': args.num_blocks,
+ 'd_hidden': args.d_hidden,
+ 'seed': args.seed,
+ 'bp_epochs': args.bp_epochs,
+ 'estimator_epochs': args.estimator_epochs,
+ 'lr_fb': args.lr_fb,
+ 'lam': args.lam,
+ 'K': args.K,
+ 'sigma_bridge': args.sigma_bridge,
+ 'ema_momentum': args.ema_momentum,
+ 'term_grad_weight': args.term_grad_weight,
+ },
+ 'bp_acc': float(bp_acc),
+ 'summary': summary,
+ 'per_layer': {},
+ 'sb_pred_error': sb_pred_error,
+ }
+
+ for m in methods:
+ save_data['per_layer'][m] = {
+ 'bp_cosine': results[m]['bp_cosine'],
+ 'perturbation_rho': results[m]['perturbation_rho'],
+ 'nudging_0.001': results[m]['nudging_0.001'],
+ 'nudging_0.003': results[m]['nudging_0.003'],
+ 'nudging_0.01': results[m]['nudging_0.01'],
+ }
+
+ out_path = os.path.join(args.output_dir,
+ f'frozen_L{args.num_blocks}_d{args.d_hidden}_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(save_data, f, indent=2)
+ print(f"\nResults saved to {out_path}")
+
+ # ----- Judgment -----
+ print(f"\n{'='*60}")
+ print("JUDGMENT")
+ print(f"{'='*60}")
+
+ best_cb = max(summary['cb_eT']['mean_rho'], summary['cb_deltaL']['mean_rho'])
+ dfa_rho = summary['dfa']['mean_rho']
+ best_cb_gamma = max(summary['cb_eT']['mean_gamma'], summary['cb_deltaL']['mean_gamma'])
+ dfa_gamma = summary['dfa']['mean_gamma']
+
+ if best_cb > dfa_rho + 0.02 and best_cb_gamma > dfa_gamma:
+ print("POSITIVE: Scalar CB recovers credit that is clearly better than DFA.")
+ print(" -> Bottleneck is in forward exploitability / local update, not estimator.")
+ print(" -> Next: Phase B (online shallow CIFAR).")
+ elif best_cb > 0.02:
+ print("MARGINAL: Scalar CB shows some signal but not clearly better than DFA.")
+ print(" -> Need more investigation before concluding estimator is the bottleneck.")
+ else:
+ print("NEGATIVE: Scalar CB cannot recover useful credit even on frozen features.")
+ print(" -> Estimator parameterization is the bottleneck.")
+ print(" -> Next: Phase C (direct vector field pilot).")
+
+ return save_data
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Frozen CIFAR Credit Recovery')
+ parser.add_argument('--num_blocks', type=int, default=4)
+ parser.add_argument('--d_hidden', type=int, default=256)
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--bp_epochs', type=int, default=100,
+ help='Epochs to train BP reference')
+ parser.add_argument('--estimator_epochs', type=int, default=100,
+ help='Epochs to train each estimator on frozen features')
+ parser.add_argument('--lr', type=float, default=1e-3, help='LR for BP reference')
+ parser.add_argument('--lr_fb', type=float, default=1e-3, help='LR for estimators')
+ parser.add_argument('--wd', type=float, default=0.01)
+ parser.add_argument('--lam', type=float, default=0.1)
+ parser.add_argument('--K', type=int, default=4)
+ parser.add_argument('--sigma_bridge', type=float, default=0.05)
+ parser.add_argument('--ema_momentum', type=float, default=0.995)
+ parser.add_argument('--term_grad_weight', type=float, default=1.0)
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpu', type=int, default=2)
+ parser.add_argument('--output_dir', type=str, default='results/frozen_cifar')
+ parser.add_argument('--retrain_bp', action='store_true')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()